Spark线性回归预测代码及注解

一、简介

线性回归使用数据的特征进行训练,以构建出一个模型(方程式)用来拟合训练的数据(最好事先判断一下这些特征和预测的结果能够真正存在线性关系)。然后使用该模型,输入相同的数量的特征,预测未来的走势。

二、对于LinearRegressionWithSGD和LinearRegression

在使用时,我们会发现,org.apache.spark.ml和org.apache.spark.mllib包下,都有关于线性回归的内容,分别对应的LinearRegression和LinearRegressionWithSGD,然后我对他们进行了比较。

按照官方说明,LinearRegressionWithSGD使用的随机梯度下降训练是没有正则化的线性回归模型的,所以不推荐使用。

我们在使用LinearRegression时,可以使用正则化,也就是 setElasticNetParam,弹性参数,用于调节L1和L2之间的比例,两种正则化比例加起来是1,详见后面正则化的设置,默认为0,只使用L2正则化(也就是岭回归),设置为1就是只用L1正则化。

在打印结果时,也能够看到很多推测结果。

...

val model = lr.fit(array(0))
println("模型截距:" + model.intercept)
println("模型权重:" + model.coefficients)
val summary = model.evaluate(array(1))
println("模型评价")
summary.residuals.show(5)
println("预测结果")
summary.predictions.show()
println("均方差:" + summary.meanSquaredError)
println("模型拟合度:" + summary.r2)
println("测试数据的条目数:" + summary.numInstances)

...

三、示例

该数据的第一列为标签(label),也可以理解成最终得到的值;而后面的8位都属于特征值,也就是用来建模的值。

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
package com.linearRegression

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

object LinearRegressionDemo {
    def main(args: Array[String]): Unit = {
        val session = SparkSession.builder()
                .master("local")
                .appName("this.getClass.getSimpleName")
                .getOrCreate()
        import session.implicits._

        val dataset = session.read.textFile("src/main/resources/lpsa.txt")
        val parseData = dataset.map { line =>
            val str =line.split(",")
            val features = str(1).split(" ").map(_.toDouble)

            LabeledPoint(str(0).toDouble, Vectors.dense(features))
        }

        val array = parseData.randomSplit(Array(0.8, 0.2), 3)

        val linearRegression = new LinearRegression()
                .setLabelCol("label")
                .setFeaturesCol("features")
                .setTol(0.001)
                .setMaxIter(100)
                .setFitIntercept(true)

        val model = linearRegression.fit(array(0))

        println("权重: " + model.coefficients)
        println("截距:" + model.intercept)
        println("特征数:" + model.numFeatures)

        val summary = model.evaluate(array(1))
        val predictions = summary.predictions
        predictions.show(20)

        println("均方差:" + summary.meanSquaredError)
        println("平均绝对值误差:" + summary.meanAbsoluteError)
        println("测试数据的条目数:" + summary.numInstances)
        println("模型拟合度:" + summary.r2)

        session.stop()
    }
}

上述代码中的一些需要注意的地方

1、对于构建LinearRegression方程

val linearRegression = new LinearRegression()
                .setLabelCol("label")
                .setFeaturesCol("features")
                .setTol(0.001)
                .setMaxIter(100)
                .setFitIntercept(true)

在构建模型方程时,我们一般都确定了 setLabelCol("label")setFeaturesCol("features")的值,而 setTol(0.001)的值的设定,属于梯度下降的步长,或称学习率,我们可以使用更多的值带入尝试,比如0.1、0.003、0.009、0.0001……直到达到一个均方差最小的情况。

此外, setMaxIter(100)为迭代次数,可以尝试使用调大和小,直到达到一个均方差最小的情况。

对于最后的 setFitIntercept(true),其实就是截距,也就是最终绘制的方程中是否经过坐标轴(0,0)原点,设置为true就是允许不经过原点,所以一般设置为true。

Original: https://blog.csdn.net/qq_40579464/article/details/116571548
Author: 赵昕彧
Title: Spark线性回归预测代码及注解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/634738/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球