Pre-pruning and Post-pruning

A decision tree may classify the training set well yet perform poorly on unseen test data. This poor generalization is usually a sign of overfitting. To prevent overfitting, the tree can be pruned, and there are two approaches: pre-pruning and post-pruning.

Pre-pruning

Pre-pruning stops growing the tree early. For example, if the maximum depth is set to 3, the trained tree is at most 3 levels deep. Pre-pruning imposes rules that limit tree growth (such as a maximum depth, a minimum number of samples per node, or a minimum impurity decrease per split). This reduces the risk of overfitting and shortens training time, but it may cause underfitting.
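
As a rough illustration, these stopping rules can be expressed as a single check made before each split. This Scala sketch assumes a hypothetical `PrePruning` settings class whose fields (`maxDepth`, `minSamplesPerNode`, `minGain`) are illustrative names, not a library API. The root is taken to be at depth 0, the convention Spark MLlib uses for `maxDepth`, so with `maxDepth = 3` every node at depth 3 becomes a leaf.

```scala
// Hypothetical settings class; field names are illustrative, not a library API.
final case class PrePruning(maxDepth: Int = 3, minSamplesPerNode: Int = 2, minGain: Double = 1e-3)

object PrePruningCheck {
  // Returns true if the node should become a leaf instead of being split further.
  // depth: depth of the node (root = 0); bestGain: impurity decrease of the best candidate split.
  def shouldStop(depth: Int, numSamples: Int, bestGain: Double, rules: PrePruning): Boolean =
    depth >= rules.maxDepth ||              // depth limit reached
    numSamples < rules.minSamplesPerNode || // too few samples to split reliably
    bestGain < rules.minGain                // best split barely improves purity
}
```

Each rule trades a little training accuracy for a smaller tree; making them stricter moves the tree toward underfitting.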

Post-pruning

Post-pruning is a global optimization method: pruning starts only after the full tree has been built. It works by replacing some subtrees with leaf nodes. The class label of each new leaf is determined by majority vote, that is, the label is the class to which most samples under that node belong.

To decide which subtrees to cut, compute the error (for example, on a held-out validation set) before and after replacing the subtree with a leaf. If the error does not increase noticeably, the subtree can be cut.
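
One common way to implement this idea is reduced-error pruning, sketched below in Scala under several assumptions: the tree has only binary threshold splits, each internal node stores its majority training label in a hypothetical `majority` field, and a separate validation set is used to measure error. Children are pruned first (bottom-up), and a node is collapsed into a majority-vote leaf when that raises the validation error by no more than `tolerance` misclassifications.

```scala
object ReducedErrorPruning {
  sealed trait Tree
  final case class Leaf(label: Int) extends Tree
  // Binary split "x(feature) <= threshold"; `majority` is the most common training label
  // among the samples that reached this node (a hypothetical field for this sketch).
  final case class Node(feature: Int, threshold: Double, left: Tree, right: Tree, majority: Int) extends Tree

  type Sample = (Array[Double], Int) // (features, label)

  def predict(t: Tree, x: Array[Double]): Int = t match {
    case Leaf(label)          => label
    case Node(f, th, l, r, _) => if (x(f) <= th) predict(l, x) else predict(r, x)
  }

  def errors(t: Tree, data: Seq[Sample]): Int = data.count { case (x, y) => predict(t, x) != y }

  def prune(t: Tree, validation: Seq[Sample], tolerance: Int = 0): Tree = t match {
    case leaf: Leaf => leaf
    case Node(f, th, l, r, majority) =>
      // Route validation samples exactly as predict() would, then prune the children first.
      val (toLeft, toRight) = validation.partition { case (x, _) => x(f) <= th }
      val subtree = Node(f, th, prune(l, toLeft, tolerance), prune(r, toRight, tolerance), majority)
      val asLeaf = Leaf(majority)
      // Cut if the leaf is not worse than the subtree by more than `tolerance` errors.
      if (errors(asLeaf, validation) <= errors(subtree, validation) + tolerance) asLeaf else subtree
  }
}
```

In this sketch a node reached by no validation samples is always collapsed, since both errors are zero; other variants keep such nodes.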

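Post-pruning generally generalizes better than pre-pruning and carries less risk of underfitting, but it costs more training time because the full tree must be grown first.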

Algorithm Summary

  • Split Criterion: choose the split attribute and split point
  • Tree Growing: recursively split the resulting subsets
  • Pruning: reduce overfitting via pre-pruning or post-pruning
  • Leaf Node Prediction: classification trees use majority voting or class probabilities; regression trees use the mean or median of the target values
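
The leaf-prediction rules in the last item can be sketched in a few lines of Scala. The object and method names are illustrative; inputs are assumed to be non-empty, and ties in the majority vote are broken arbitrarily.

```scala
object LeafPrediction {
  // Classification leaf: majority vote over the labels that reached the leaf (ties arbitrary).
  def majorityVote(labels: Seq[Int]): Int =
    labels.groupBy(identity).maxBy(_._2.size)._1

  // Classification leaf: estimated class probabilities (label -> fraction of samples).
  def classProbabilities(labels: Seq[Int]): Map[Int, Double] =
    labels.groupBy(identity).map { case (k, v) => k -> v.size.toDouble / labels.size }

  // Regression leaf: mean of the target values.
  def mean(ys: Seq[Double]): Double = ys.sum / ys.size

  // Regression leaf: median of the target values (average of the two middle values if even).
  def median(ys: Seq[Double]): Double = {
    val s = ys.sorted
    val n = s.size
    if (n % 2 == 1) s(n / 2) else (s(n / 2 - 1) + s(n / 2)) / 2.0
  }
}
```

The median is less sensitive to outliers among the samples in a leaf, which is one reason it is sometimes preferred over the mean for regression trees.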

ID3, C4.5, and CART differ mainly in their split criteria, the attribute types they support, their tree structure, and their pruning methods.

ID3

Disadvantages:

  • ID3 uses information gain as the criterion for choosing the splitting attribute at the root and at each internal node. Information gain is biased toward attributes with many distinct values, and in some cases such attributes carry little useful information (an ID-like attribute with a unique value per sample is the extreme case)
  • ID3 can only build trees on datasets whose attributes are discrete

Core Idea:

Information gain: split on the attribute that maximizes the decrease in entropy (the information gain). Only discrete attributes are supported; continuous attributes must be discretized first.

Algorithm Flow:

  • Compute the entropy H(D) of the current dataset D
  • For each attribute a, compute Gain(D,a)
  • Split on the attribute with the largest Gain and recursively build subtrees on the resulting subsets
  • Stop when no attributes remain or the samples at a node are sufficiently pure
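
A minimal Scala sketch of the first three steps, assuming each attribute column is given as a sequence of discrete string values aligned with the class labels. The helper names `entropy`, `infoGain`, and `bestAttribute` are illustrative, not from a library.

```scala
object Id3Gain {
  private val Ln2 = math.log(2)

  // H(D) = -sum_k p_k * log2(p_k), p_k = fraction of samples in class k (non-empty input assumed).
  def entropy(labels: Seq[String]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map { g =>
      val p = g.size / n
      -p * math.log(p) / Ln2
    }.sum
  }

  // Gain(D, a) = H(D) - sum_v |D_v| / |D| * H(D_v), where D_v holds the samples with a = v.
  def infoGain(values: Seq[String], labels: Seq[String]): Double = {
    val n = labels.size.toDouble
    val conditional = values.zip(labels).groupBy(_._1).values.map { g =>
      (g.size / n) * entropy(g.map(_._2))
    }.sum
    entropy(labels) - conditional
  }

  // Illustrative helper: name of the attribute column with the largest information gain.
  def bestAttribute(columns: Map[String, Seq[String]], labels: Seq[String]): String =
    columns.maxBy { case (_, vs) => infoGain(vs, labels) }._1
}
```

For example, an attribute that separates the labels perfectly, such as values `a, a, b, b` for labels `y, y, n, n`, has a gain of 1 bit, while values `a, b, a, b` give a gain of 0.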

C4.5

Why is C4.5 better than ID3?

  • Uses the information gain ratio to choose attributes, which reduces information gain's bias toward many-valued attributes
  • Handles continuous numeric attributes
  • Applies post-pruning
  • Handles missing values
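
The first two points can be sketched in Scala as follows. This is a simplified illustration, not Quinlan's implementation: the original C4.5 also restricts the gain-ratio comparison to attributes whose gain is at least average, and it differs in some details of the threshold search. Here a continuous attribute is handled by trying the midpoint between each pair of adjacent distinct values as a binary threshold.

```scala
object C45Criteria {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Entropy of a label sequence (non-empty input assumed).
  def entropy[L](labels: Seq[L]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map { g => val p = g.size / n; -p * log2(p) }.sum
  }

  // Gain ratio = Gain(D, a) / IV(a), with IV(a) = -sum_v (|D_v|/|D|) * log2(|D_v|/|D|).
  def gainRatio[V, L](values: Seq[V], labels: Seq[L]): Double = {
    val n = labels.size.toDouble
    val groups = values.zip(labels).groupBy(_._1).values.map(_.map(_._2))
    val gain = entropy(labels) - groups.map(g => g.size / n * entropy(g)).sum
    val iv = groups.map { g => val w = g.size / n; -w * log2(w) }.sum
    if (iv == 0.0) 0.0 else gain / iv // a single-valued attribute carries no split information
  }

  // Continuous attribute: try midpoints between adjacent distinct sorted values as thresholds t
  // and return (t, information gain) for the best binary split (x <= t vs x > t), if any.
  def bestThreshold[L](xs: Seq[Double], labels: Seq[L]): Option[(Double, Double)] = {
    val n = labels.size.toDouble
    val distinct = xs.distinct.sorted
    val candidates = distinct.zip(distinct.drop(1)).map { case (a, b) => (a + b) / 2 }
    val scored = candidates.map { t =>
      val (le, gt) = xs.zip(labels).partition(_._1 <= t)
      val gain = entropy(labels) -
        le.size / n * entropy(le.map(_._2)) - gt.size / n * entropy(gt.map(_._2))
      (t, gain)
    }
    if (scored.isEmpty) None else Some(scored.maxBy(_._2))
  }
}
```

For instance, an ID-like attribute with values `1, 2, 3, 4` for labels `y, y, n, n` has a gain of 1 but an intrinsic value of 2, so its gain ratio drops to 0.5, while a two-valued attribute that splits the labels perfectly keeps a ratio of 1.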

Advantages:

  • The generated classification rules are easy to understand and have high accuracy

Disadvantages:

  • Building the tree requires multiple sequential scans and sorts of the dataset, which makes the algorithm inefficient
  • It only works on datasets that fit in memory; when the training set is too large to fit, the program cannot run

CART

Compared with C4.5, CART uses a simpler binary tree model (every internal node has exactly two branches). For classification, it selects features with the Gini index, which approximates entropy but avoids logarithms, simplifying the computation.
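
The approximation can be made concrete. With p_k the fraction of samples of D in class k, the Gini index and its relation to entropy (written here with the natural logarithm; the base only rescales H(D)) are:

```latex
\begin{aligned}
\mathrm{Gini}(D) &= \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2 \\
-\ln p_k &\approx 1 - p_k \qquad \text{(first-order Taylor expansion around } p_k = 1\text{)} \\
H(D) &= -\sum_{k=1}^{K} p_k \ln p_k \approx \sum_{k=1}^{K} p_k (1 - p_k) = \mathrm{Gini}(D)
\end{aligned}
```

Replacing the logarithm with a polynomial in p_k is what makes the Gini index cheaper to evaluate at every candidate split.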

Key Steps:

  • For each feature, enumerate all candidate split points and compute the decrease in Gini impurity (classification) or mean squared error (regression)
  • Make a binary split on the (feature, split point) pair with the largest decrease
  • Stop when a node has fewer samples than a threshold or its purity satisfies the stopping criterion
  • Obtain the final tree through cost-complexity pruning
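
A Scala sketch of the split search for one numeric feature in a classification tree, plus the weakest-link rule of cost-complexity pruning. The names (`CartSketch`, `bestSplit`, `Candidate`, `weakestLink`) are hypothetical; a full CART implementation would repeat the search over all features and also handle categorical features and the regression (squared-error) case, which this sketch omits. Cost-complexity pruning minimizes R_α(T) = R(T) + α|T|, where R(T) is the training error and |T| the number of leaves. Collapsing the subtree T_t rooted at node t into a leaf pays off once α ≥ g(t) = (R(t) - R(T_t)) / (|T_t| - 1), so each pruning round cuts the node with the smallest g(t).

```scala
object CartSketch {
  // Gini(D) = 1 - sum_k p_k^2; an empty subset contributes no impurity.
  def gini(labels: Seq[Int]): Double = {
    val n = labels.size.toDouble
    if (n == 0) 0.0
    else 1.0 - labels.groupBy(identity).values.map { g => val p = g.size / n; p * p }.sum
  }

  // Enumerate midpoints between adjacent distinct values of one numeric feature and return
  // (threshold, Gini decrease) for the binary split with the largest decrease, if any.
  def bestSplit(xs: Seq[Double], ys: Seq[Int]): Option[(Double, Double)] = {
    val n = ys.size.toDouble
    val parent = gini(ys)
    val vs = xs.distinct.sorted
    val splits = vs.zip(vs.drop(1)).map { case (a, b) =>
      val t = (a + b) / 2
      val (left, right) = xs.zip(ys).partition(_._1 <= t)
      val weighted = left.size / n * gini(left.map(_._2)) + right.size / n * gini(right.map(_._2))
      (t, parent - weighted)
    }
    if (splits.isEmpty) None else Some(splits.maxBy(_._2))
  }

  // Internal node t: leafError = R(t) if t were a leaf, subtreeError = R(T_t),
  // leaves = number of leaves of T_t (must be >= 2). Both errors must use the same scale.
  final case class Candidate(id: String, leafError: Double, subtreeError: Double, leaves: Int)

  // g(t) = (R(t) - R(T_t)) / (leaves - 1)
  def alpha(c: Candidate): Double = (c.leafError - c.subtreeError) / (c.leaves - 1)

  // The weakest link is the internal node with the smallest g(t).
  def weakestLink(candidates: Seq[Candidate]): Candidate = candidates.minBy(c => alpha(c))
}
```

Repeating the weakest-link cut yields a nested sequence of subtrees; CART then selects one of them, typically by cross-validation or on a separate validation set.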

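Decision Tree Example

The following Spark MLlib program (Scala) trains a binary classification tree on a LIBSVM-format file, using entropy as the impurity measure and a maximum depth of 3, and then reports the error rate on a held-out test split.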

package icu.wzk.logic


import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}


object LogicTest2 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("dt")
    val sc = new SparkContext(conf)
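    sc.setLogLevel("WARN")
    // Read the dataset (LIBSVM format: label index1:value1 index2:value2 ...)
    val labeledPointData = MLUtils.loadLibSVMFile(sc, "./data/dt.data")
    // Randomly split into 80% training and 20% test data (fixed seed for reproducibility)
    val trainTestData = labeledPointData.randomSplit(Array(0.8, 0.2), seed = 1)
    val trainData = trainTestData(0)
    val testData = trainTestData(1)
    // Train the model.
    // categoricalFeaturesInfo maps feature index -> number of categories:
    // features 0 and 1 take 4 values, features 2 and 3 take 3 values (encoded as 0, 1, 2, ...)
    val categoricalFeaturesInfo = Map[Int, Int](0 -> 4, 1 -> 4, 2 -> 3, 3 -> 3)
    // numClasses = 2, impurity = "entropy" (information gain), maxDepth = 3 (pre-pruning), maxBins = 32
    val model = DecisionTree.trainClassifier(trainData, 2,
      categoricalFeaturesInfo, "entropy", 3, 32)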
    // Predict: pair each test prediction with its true label
    val testRes = testData.map(data => {
      (model.predict(data.features), data.label)
    })
    testRes.take(10).foreach(println(_))
    // Evaluate: fraction of misclassified test samples
    val errorRate = testRes.filter(x => x._1 != x._2).count().toDouble /
      testData.count()
    println("Error rate: " + errorRate)
    // Print the learned tree as nested if-else rules
    println(model.toDebugString)
    sc.stop()
  }
}