逻辑回归应用场景
- 广告点击率预测
- 垃圾邮件识别
- 疾病诊断
- 金融欺诈检测
- 虚假账号检测
从上面的例子可以看出一个共同特点:它们都涉及两个类别之间的判断。逻辑回归正是解决二分类问题的首选工具。
逻辑回归原理
要掌握逻辑回归,必须理解两个关键点:
- 逻辑回归的输入是什么
- 如何解读逻辑回归的输出
输入函数
逻辑回归的输入是线性回归的结果。
激活函数
Sigmoid 函数:
判决标准:
- 将回归结果输入 Sigmoid 函数
- 输出结果:介于 [0, 1] 区间的一个概率值,默认阈值为 0.5
逻辑回归的最终分类由属于某类别的概率值决定。这个类别默认标记为 1,另一个类别标记为 0。
输出结果解读:假设有两个类别 A 和 B,概率值表示属于类别 A(1)的概率。如果一个样本输入逻辑回归输出 0.55,这个概率超过 0.5,意味着训练或预测结果为类别 A(1)。反之,如果结果为 0.3,则训练或预测结果为类别 B(0)。
逻辑回归的阈值是可以调整的。例如,如果将阈值设为 0.6,则输出 0.55 会被归类为类别 B。
损失与优化
逻辑回归中的损失称为对数似然损失。公式为:
按类别分开:
其中 Y 是真实值,hθ(x) 是预测值。
如何理解单独的表达式?这需要通过对数函数的图像来理解。
在所有情况下,我们都希望损失函数值尽可能小。
分情况讨论,对应的损失函数值:
- 当 y=1 时,希望 hθ(x) 尽可能大
- 当 y=0 时,希望 hθ(x) 尽可能小
- 合并后的完整损失函数
优化逻辑
同样使用梯度下降优化算法来降低损失函数值,更新逻辑回归算法的权重参数,以增加原本属于类别 1 的样本的概率,降低原本属于类别 0 的样本的概率。
案例实战
数据准备
wget https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv -O pima.csv
代码实现
package icu.wzk.logic
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
object LogicTest {
def main(args: Array[String]): Unit = {
// ① 本地模式演示,生产环境请修改 master
val conf = new SparkConf()
.setAppName("LogisticRegression-RDD")
.setMaster("local[*]")
val sc = new SparkContext(conf)
sc.setLogLevel("WARN")
val raw = sc.textFile("pima.csv")
val points = raw.map { line =>
val cols = line.split(",").map(_.toDouble)
LabeledPoint(cols(8), Vectors.dense(cols.slice(0, 8)))
}.cache()
// ③ 训练集与测试集划分
val Array(train, test) = points.randomSplit(Array(0.8, 0.2), seed = 42)
// ④ 使用 LR+SGD 训练,迭代 100 次
val model = LogisticRegressionWithSGD.train(train, numIterations = 100)
// ⑤ 预测并计算简单准确率
val predictAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = predictAndLabel.filter { case (p, l) => p == l }.count().toDouble / test.count()
predictAndLabel.foreach { case (p, l) => println(s"pred=$p\tlabel=$l") }
println(f"accuracy = $accuracy%.4f")
sc.stop()
}
}
代码说明:
- 从本地文件 pima.csv 读取每一行数据。
- 每行按逗号分割,转换为数组
cols。 - 假设每行有 9 个值:前 8 个是特征,第 9 个(
cols(8))是标签。 - 构造
LabeledPoint对象,这是 Spark MLlib 中训练样本的格式(包含特征和标签)。 .cache()将数据缓存在内存中,以加快后续训练速度。- 将整个数据集随机划分为 80%(训练)+ 20%(测试)。
- 使用固定的随机种子 42,以确保多次运行结果一致。
- 使用
LogisticRegressionWithSGD(基于随机梯度下降的逻辑回归)在训练集上进行训练。 - 设置迭代次数为 100;模型训练 100 次迭代步骤以接近最优解。
- 对测试集中的每个样本进行预测,返回(预测值,实际标签)元组。
- 使用简单相等比较来判断预测是否正确。
- 计算正确预测数与测试样本总数的比值,得到预测准确率。
- 将每个预测结果和对应的标签输出到终端,供人工对比。
- 以 4 位小数的形式打印最终准确率。
最终准确率约为:76.62%