逻辑回归应用场景

  • 广告点击率预测
  • 垃圾邮件识别
  • 疾病诊断
  • 金融欺诈检测
  • 虚假账号检测

从上面的例子可以看出一个共同特点:它们都涉及两个类别之间的判断。逻辑回归正是解决二分类问题的首选工具。

逻辑回归原理

要掌握逻辑回归,必须理解两个关键点:

  • 逻辑回归的输入是什么
  • 如何解读逻辑回归的输出

输入函数

逻辑回归的输入是线性回归的结果。

激活函数

Sigmoid 函数:

判决标准:

  • 将回归结果输入 Sigmoid 函数
  • 输出结果:介于 [0, 1] 区间的一个概率值,默认阈值为 0.5

逻辑回归的最终分类由属于某类别的概率值决定。这个类别默认标记为 1,另一个类别标记为 0。

输出结果解读:假设有两个类别 A 和 B,概率值表示属于类别 A(1)的概率。如果一个样本输入逻辑回归输出 0.55,这个概率超过 0.5,意味着训练或预测结果为类别 A(1)。反之,如果结果为 0.3,则训练或预测结果为类别 B(0)。

逻辑回归的阈值是可以调整的。例如,如果将阈值设为 0.6,则输出 0.55 会被归类为类别 B。

损失与优化

逻辑回归中的损失称为对数似然损失。公式为:

按类别分开:

其中 Y 是真实值,hθ(x) 是预测值。

如何理解单独的表达式?这需要通过对数函数的图像来理解。

在所有情况下,我们都希望损失函数值尽可能小。

分情况讨论,对应的损失函数值:

  • 当 y=1 时,希望 hθ(x) 尽可能大
  • 当 y=0 时,希望 hθ(x) 尽可能小
  • 合并后的完整损失函数

优化逻辑

同样使用梯度下降优化算法来降低损失函数值,更新逻辑回归算法的权重参数,以增加原本属于类别 1 的样本的概率,降低原本属于类别 0 的样本的概率。

案例实战

数据准备

wget https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv -O pima.csv

代码实现

package icu.wzk.logic

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}


object LogicTest {
  def main(args: Array[String]): Unit = {

    // ① 本地模式演示,生产环境请修改 master
    val conf = new SparkConf()
      .setAppName("LogisticRegression-RDD")
      .setMaster("local[*]")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")
    val raw = sc.textFile("pima.csv")
    val points = raw.map { line =>
      val cols = line.split(",").map(_.toDouble)
      LabeledPoint(cols(8), Vectors.dense(cols.slice(0, 8)))
    }.cache()

    // ③ 训练集与测试集划分
    val Array(train, test) = points.randomSplit(Array(0.8, 0.2), seed = 42)

    // ④ 使用 LR+SGD 训练,迭代 100 次
    val model = LogisticRegressionWithSGD.train(train, numIterations = 100)

    // ⑤ 预测并计算简单准确率
    val predictAndLabel = test.map(p => (model.predict(p.features), p.label))
    val accuracy = predictAndLabel.filter { case (p, l) => p == l }.count().toDouble / test.count()

    predictAndLabel.foreach { case (p, l) => println(s"pred=$p\tlabel=$l") }
    println(f"accuracy = $accuracy%.4f")

    sc.stop()
  }
}

代码说明:

  • 从本地文件 pima.csv 读取每一行数据。
  • 每行按逗号分割,转换为数组 cols
  • 假设每行有 9 个值:前 8 个是特征,第 9 个(cols(8))是标签。
  • 构造 LabeledPoint 对象,这是 Spark MLlib 中训练样本的格式(包含特征和标签)。
  • .cache() 将数据缓存在内存中,以加快后续训练速度。
  • 将整个数据集随机划分为 80%(训练)+ 20%(测试)。
  • 使用固定的随机种子 42,以确保多次运行结果一致。
  • 使用 LogisticRegressionWithSGD(基于随机梯度下降的逻辑回归)在训练集上进行训练。
  • 设置迭代次数为 100;模型训练 100 次迭代步骤以接近最优解。
  • 对测试集中的每个样本进行预测,返回(预测值,实际标签)元组。
  • 使用简单相等比较来判断预测是否正确。
  • 计算正确预测数与测试样本总数的比值,得到预测准确率。
  • 将每个预测结果和对应的标签输出到终端,供人工对比。
  • 以 4 位小数的形式打印最终准确率。

最终准确率约为:76.62%