Big Data 200 - Decision Tree Information Gain in Detail

TL;DR

  • Scenario: Use information entropy and information gain to explain why a decision tree chooses a particular column to split on, and reproduce the “best split column” in Python.
  • Conclusion: Maximizing information gain ⇔ minimizing the weighted average entropy of the child nodes; ID3 builds the tree greedily and recursively but tends to favor features with many distinct values.
  • Output: The reasoning behind information gain, code skeletons for choosing the best split, splitting by a column, and building the tree recursively, plus a quick reference of common pitfalls.

Information Gain

The ultimate optimization goal of a decision tree is to minimize the total impurity of its leaf nodes, as measured by the chosen impurity metric.

However, the globally optimal tree cannot be found simply and efficiently, so we rely on locally optimal choices to guide the modeling process: by setting an optimization condition and choosing the locally best split at each step, we gradually approach a solution that is as close to the global optimum as possible.

With information entropy as the metric, the local optimality condition for growing a decision tree is easy to understand: when choosing an attribute test condition to split a node (dataset), pick the feature that makes the weighted entropy of the child nodes as small as possible.

In other words, the difference between the parent node's entropy and the child nodes' entropy should be as large as possible.

As a formula:

Information Gain = Entropy of the parent node - Weighted average entropy of the child nodes

Decision tree induction algorithms usually choose the test condition that maximizes the gain. Because the parent entropy Ent(D) is the same for every candidate test condition, maximizing the gain is equivalent to minimizing the weighted average impurity of the branch nodes. When entropy is used as the impurity measure in this formula, the entropy difference is what is called “information gain”.
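For reference, the same idea in the usual textbook notation. Here D is the dataset at a node, p_k is the proportion of class k in D, and D^v is the subset of D whose attribute a takes value v (notation assumed beyond the Ent(D) used above):

```latex
\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k
\qquad
\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)

\arg\max_a \mathrm{Gain}(D, a) = \arg\min_a \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)
```

The last line restates the equivalence: Ent(D) does not depend on a, so it drops out of the comparison.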


Split Dataset

A classification algorithm must not only measure information entropy but also split the dataset effectively. Information entropy measures how pure a dataset is, while splitting the dataset is the key step in building a decision tree.

Once we know how to calculate information entropy, we can use information gain to evaluate the quality of a split. The steps are:

  1. Calculate the total information entropy of the current dataset
  2. For each candidate feature:
    • Split the dataset into subsets, one for each possible value of that feature
    • Calculate the information entropy of each subset
    • Calculate the weighted average information entropy after splitting, weighting each subset by its share of the rows
    • Calculate the feature's information gain (total information entropy minus weighted average information entropy)
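The steps above can be sketched in plain Python, independent of pandas. The helper names and the five-row dataset are hypothetical, chosen only for illustration; with these values the gains come out to about 0.420 for column 0 and 0.171 for column 1, which match the manual figures quoted below.

```python
from collections import Counter
from math import log2

# Hypothetical helpers (not from the original code) that follow the listed steps.
def entropy(labels):
    """Shannon entropy (base 2) of a sequence of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def information_gain(rows, col):
    """Gain from splitting `rows` (label in the last position) on column `col`."""
    total_ent = entropy([r[-1] for r in rows])       # Step 1: total entropy
    subsets = {}
    for r in rows:                                   # Step 2: one subset per value
        subsets.setdefault(r[col], []).append(r[-1])
    weighted = sum(len(s) / len(rows) * entropy(s)   # Subset entropies, weighted by size
                   for s in subsets.values())
    return total_ent - weighted                      # Gain = total - weighted average

# Hypothetical toy data: two binary features, label in the last position
rows = [
    [1, 1, "yes"],
    [1, 1, "yes"],
    [1, 0, "no"],
    [0, 1, "no"],
    [0, 1, "no"],
]
gains = [information_gain(rows, i) for i in range(len(rows[0]) - 1)]
for i, g in enumerate(gains):
    print(f"column {i}: gain = {g:.3f}")
```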

Best Split Function for a Dataset

The splitting criterion is to choose the feature with the maximum information gain, i.e., the split along which entropy (uncertainty) decreases fastest.

From the manual calculation, we know:

  • Column 0 has an information gain of 0.42, and column 1 has 0.17
  • So column 0 should be chosen to split the dataset

The following Python code prints the information gain of each column and returns the best one:

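import numpy as np
import pandas as pd

# Information entropy of the label column (last column) of a DataFrame
def calEnt(dataSet):
    n = dataSet.shape[0]  # Total number of rows
    iset = dataSet.iloc[:, -1].value_counts()  # Count of each label class
    p = iset / n  # Proportion of each label class
    p = p[p > 0]  # Drop zero proportions (e.g. unused categorical levels) to avoid log2(0)
    ent = (-p * np.log2(p)).sum()  # Information entropy
    return ent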

# Select the optimal column to split on (the one with maximum information gain)
def bestSplit(dataSet):
    baseEnt = calEnt(dataSet)  # Entropy of the current (parent) dataset
    bestGain = 0  # Best information gain found so far
    axis = -1  # Index of the best split column; stays -1 if no column has positive gain

    for i in range(dataSet.shape[1] - 1):  # Loop over every feature column (label excluded)
        levels = dataSet.iloc[:, i].value_counts().index  # All distinct values of the current column
        ents = 0  # Weighted average entropy of the child nodes

        for j in levels:  # Loop over each value of the current column
            childSet = dataSet[dataSet.iloc[:, i] == j]  # Child-node subset for this value
            ent = calEnt(childSet)  # Entropy of this child node
            ents += (childSet.shape[0] / dataSet.shape[0]) * ent  # Add its row-weighted entropy

        print('Column {} weighted child entropy is {}'.format(i, ents))
        infoGain = baseEnt - ents  # Information gain of the current column
        print('Column {} information gain is {}\n'.format(i, infoGain))

        if infoGain > bestGain:
            bestGain = infoGain  # Keep the maximum information gain
            axis = i  # Index of the column with maximum information gain

    print("Column {} is the optimal split column".format(axis))
    return axis

Split Dataset by Given Column

Since the best split function returns the index of the best split column, we can now build a function that splits the dataset on a given column and value.

"""
Function: Split dataset by given column
Parameters:
dataSet: Original dataset
axis: Specified column index
value: Specified attribute value
return: redataSet: Dataset after splitting by specified column index and attribute value
"""
def mySplit(dataSet, axis, value):
    col = dataSet.columns[axis]
    redataSet = dataSet.loc[dataSet[col] == value, :].drop(col, axis=1)
    return redataSet

Decision Tree Generation

We now have all the helper functions needed to build a decision tree from a dataset. The algorithm works as follows:

  • Take the original dataset and split it on the best attribute. Because a feature may have more than two values, a split may produce more than two branches
  • After the first split, each subset is passed down to the next node along its branch
  • At each new node the data can be split again, so the dataset can be processed recursively

The recursion stops when:

  • All attributes have already been used to split the dataset
  • All instances under the current branch belong to the same class
  • The current node contains no samples and cannot be split

In cases 1 and 2, mark the current node as a leaf and set its class to the majority class of the samples at that node (in case 2 this is simply the class all samples share). Any sample that reaches a leaf is assigned that leaf's class. In case 3, also mark the current node as a leaf, but set its class to the majority class of its parent node.
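Because a sample reaching a leaf simply receives that leaf's class, prediction with the nested-dictionary tree returned by createTree (below) is a walk from the root to a leaf. This is a minimal sketch; the function name `classify`, the `default` fallback for unseen values (loosely mirroring case 3), and the tree with feature names `feat0`/`feat1` are all hypothetical.

```python
# Hypothetical prediction helper for trees of the form {feature: {value: subtree_or_label}}
def classify(tree, sample, default=None):
    """Walk the tree to a leaf; `sample` maps feature names to values.

    If the sample has a value with no matching branch, `default` is returned.
    """
    while isinstance(tree, dict):
        feature = next(iter(tree))   # The single split feature at this node
        branches = tree[feature]
        value = sample[feature]
        if value not in branches:    # Value never seen during training
            return default
        tree = branches[value]       # Descend one level
    return tree                      # Reached a leaf: its class label

# Hypothetical hand-written tree in the nested-dictionary format
tree = {"feat0": {0: "no", 1: {"feat1": {0: "no", 1: "yes"}}}}
print(classify(tree, {"feat0": 1, "feat1": 1}))
print(classify(tree, {"feat0": 0, "feat1": 1}))
```

The loop stops as soon as it reaches a non-dictionary value, which is exactly how createTree represents leaves.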


ID3 Algorithm

ID3 was developed by J. Ross Quinlan and presented systematically in his 1986 paper “Induction of Decision Trees”, published in the journal Machine Learning. It is one of the earliest complete and systematic decision tree learning algorithms. It uses the information gain criterion from information theory and builds the tree recursively with a top-down greedy strategy; its sound theoretical basis and simple, intuitive implementation have made it widely used in machine learning.

The core idea of ID3: at every non-leaf node, choose the attribute with the largest information gain as the splitting attribute.

In practice, ID3 applies the information gain criterion to select a feature at each node and builds the tree recursively:

  • Starting from the root, compute the information gain of every candidate feature at the node.
  • Choose the feature with the largest information gain as the node's feature, and create one child node for each distinct value of that feature.
  • Apply the same procedure recursively to each child node.
  • Stop when every feature's information gain is very small or no features remain; the result is a decision tree.

A Python implementation:

"""
Function: Split dataset based on maximum information gain, recursively build decision tree
Parameters:
dataSet: Original dataset (last column is label)
return: myTree: Tree in dictionary form
"""
def createTree(dataSet):
    featlist = list(dataSet.columns)  # Extract all columns of dataset
    classlist = dataSet.iloc[:, -1].value_counts()  # Get last column class labels

    # Check if most label count equals dataset row count, or dataset has only one column
    if classlist.iloc[0] == dataSet.shape[0] or dataSet.shape[1] == 1:
        return classlist.index[0]  # If yes, return class label

    axis = bestSplit(dataSet)  # Determine index of current best split column
    bestfeat = featlist[axis]  # Get feature corresponding to this index
    myTree = {bestfeat: {}}  # Store tree info using nested dictionary

    del featlist[axis]  # Delete current feature
    valuelist = set(dataSet.iloc[:, axis])  # Extract all attribute values of best split column

    # Recursively build tree for each attribute value
    for value in valuelist:
        myTree[bestfeat][value] = createTree(mySplit(dataSet, axis, value))

    return myTree

ID3 Algorithm Limitations

ID3's limitations stem mainly from its local optimality condition, i.e., the way information gain is calculated. The main limitations are:

  • Features with more distinct values (a higher branching factor) tend to yield a lower weighted child entropy, so ID3 favors them even when they say little about the label. In the extreme case, splitting on an ID column puts each sample in its own branch: every branch is 100% pure and the gain is maximal, yet the split is useless for classifying new data.
  • It cannot handle continuous variables directly; they must be discretized before ID3 can be applied.
  • It is relatively sensitive to missing values, which must be handled before running ID3.
  • It has no pruning, so it easily overfits: it performs well on the training set but poorly on the test set.
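A small sketch of the ID-column problem from the first bullet: a column with a unique value per row sends every sample to its own pure branch, so its gain equals the parent entropy, the largest value possible. The helper names and the data are hypothetical.

```python
from collections import Counter
from math import log2

# Hypothetical helpers for illustration only
def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def gain_of_split(features, labels):
    """Information gain of splitting `labels` by the parallel list `features`."""
    groups = {}
    for x, y in zip(features, labels):
        groups.setdefault(x, []).append(y)
    n = len(labels)
    weighted = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - weighted

labels = ["yes", "yes", "no", "no", "no"]
row_id = [1, 2, 3, 4, 5]   # Hypothetical unique identifier column
useful = [1, 1, 1, 0, 0]   # Hypothetical genuinely informative feature

print(round(entropy(labels), 3))                 # Parent entropy
print(round(gain_of_split(row_id, labels), 3))   # Same as the parent entropy
print(round(gain_of_split(useful, labels), 3))
```

ID3 would pick `row_id` here, even though an identifier says nothing about unseen rows.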

Many of the improvements made to ID3 ended up forming the core of the C4.5 algorithm.


C4.5 Algorithm

C4.5 is similar to ID3 and improves on it: during tree generation, it selects features using the information gain ratio instead of raw information gain.
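The gain ratio divides the information gain by the split information (also called the intrinsic value) of the feature, which is the entropy of the feature's own value distribution. Because many-valued features have large split information, they are penalized. This is a minimal sketch with hypothetical names and data; it ranks features by gain ratio alone and is not a full C4.5 implementation (continuous attributes, missing values, and pruning are not covered).

```python
from collections import Counter
from math import log2

# Hypothetical helpers for illustration only
def entropy(values):
    """Base-2 entropy of the empirical distribution of `values`."""
    n = len(values)
    return -sum((c / n) * log2(c / n) for c in Counter(values).values())

def gain_ratio(features, labels):
    """Information gain divided by the split information of `features`."""
    n = len(labels)
    groups = {}
    for x, y in zip(features, labels):
        groups.setdefault(x, []).append(y)
    gain = entropy(labels) - sum(len(g) / n * entropy(g) for g in groups.values())
    split_info = entropy(features)  # Entropy of the feature's own value distribution
    return gain / split_info if split_info > 0 else 0.0

labels = ["yes", "yes", "no", "no", "no"]
row_id = [1, 2, 3, 4, 5]   # Hypothetical unique identifier column
useful = [1, 1, 1, 0, 0]   # Hypothetical informative feature

print(round(gain_ratio(row_id, labels), 3))
print(round(gain_ratio(useful, labels), 3))
```

On this data the ID column has the larger raw gain but the smaller gain ratio, so the informative feature wins.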


Error Quick Reference

| Symptom | Root Cause | Diagnosis | Fix |
|---|---|---|---|
| Entropy calculation returns nan or inf | log2(0) when a probability is 0 | Check whether p contains 0 in calEnt() | Filter out zero probabilities with p = p[p > 0] before computing; or add a tiny epsilon |
| Information gain is 0 or very small for every column | Wrong label column selected, or incorrect filtering leaves the entropy unchanged | Check that dataSet.iloc[:, -1] really is the label; check whether subsets are empty | Make the label column explicit; check childSet.shape[0]; skip empty subsets |
| bestSplit output looks unstable | set() iteration order is arbitrary; the order of value_counts().index changes how the output is printed | Where levels and valuelist come from | Use sorted(levels) for deterministic output; sort tree branches as well |
| drop(…, axis=1) raises a parameter error or warning | pandas version differences or inconsistent argument style | The drop(col, axis=1) call | Use drop(columns=[col]) |
| Recursively built tree is abnormally deep or overfits | No pruning; many discrete values cause a branch explosion | Branch count per level in createTree() | Add a minimum sample count, maximum depth, or information gain threshold; or switch to C4.5 or a pruning strategy |
| An ID column or other unique-value column is chosen as the best feature | Information gain favors features with many values | That column has far more distinct values than the others | Use gain ratio (C4.5), or filter out high-cardinality identifier columns beforehand |
| Continuous features cannot be handled | The code splits on discrete values: levels = value_counts().index enumerates every distinct value directly | Look for continuous features | Bin/discretize first; or implement threshold splitting (sort, then enumerate candidate split points) |
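For the last row, here is a minimal sketch of threshold splitting: sort the values, then evaluate each candidate split point as a binary test `value <= t`. Using midpoints between consecutive distinct values as candidates is a common convention assumed here, and the function name and data are hypothetical.

```python
from collections import Counter
from math import log2

# Hypothetical helpers for illustration only
def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def best_threshold(values, labels):
    """Return (threshold, gain) of the best binary split `value <= threshold`.

    Candidates are midpoints between consecutive distinct sorted values;
    threshold is None if no split has positive gain (e.g. all values equal).
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    base = entropy(labels)
    best_t, best_gain = None, 0.0
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue                          # No threshold between equal values
        t = (lo + hi) / 2
        left = [y for _, y in pairs[:i]]      # Values <= t
        right = [y for _, y in pairs[i:]]     # Values > t
        weighted = len(left) / n * entropy(left) + len(right) / n * entropy(right)
        if base - weighted > best_gain:
            best_t, best_gain = t, base - weighted
    return best_t, best_gain

# Hypothetical continuous feature with binary labels
values = [2.1, 3.5, 1.0, 4.8, 3.9, 0.5]
labels = ["no", "yes", "no", "yes", "yes", "no"]
print(best_threshold(values, labels))
```

Each candidate split is scored with the same information gain as before, so the continuous feature can then compete with discrete features in bestSplit-style selection.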