Big Data 200 - Decision Tree Information Gain Detailed
TL;DR
- Scenario: use information entropy / information gain to explain why a decision tree selects a particular column to split on, and reproduce the "best split column" in Python.
- Conclusion: maximizing information gain ⇔ minimizing the weighted average entropy of the child nodes; ID3 builds the tree greedily and recursively, but tends to prefer features with many values.
- Output: derivation context for information gain + code skeletons for best split / per-column split / recursive tree building + a quick reference of common pitfalls.
Information Gain
The ultimate optimization goal of a decision tree is to minimize the total impurity of the leaf nodes, i.e., to minimize the chosen impurity measure.
At the same time, the globally optimal tree cannot be found simply and efficiently, so we resort to a locally optimal (greedy) strategy to guide the modeling process: by applying an optimization criterion at each step, we gradually approach the best solution reachable under locally optimal choices.
Under the information-entropy measure, the local optimality criterion for growing a decision tree is easy to state: when choosing an attribute test to split a node (dataset), pick the feature that makes the child nodes' information entropy smallest.
In other words, maximize the difference between the parent node's entropy and the weighted entropy of the child nodes.
As a formula:

Information gain = parent node entropy − weighted average of child node entropy, i.e. Gain(D, a) = Ent(D) − Σ (|Dv| / |D|) × Ent(Dv), where Dv is the subset of D taking value v on attribute a.

A decision-tree induction algorithm usually selects the test condition that maximizes this gain. Since Ent(D) is the same constant for every candidate test condition, maximizing the gain is equivalent to minimizing the weighted average impurity of the branch nodes. When entropy is chosen as the impurity measure in this formula, the entropy difference is the so-called "information gain".
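As a minimal numeric check of this equivalence, here is a short NumPy sketch (the labels and the split are toy assumptions, not from the original dataset) computing a parent entropy, the weighted child entropies, and their difference:

```python
import numpy as np

def entropy(labels):
    # class probabilities from label counts
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

# toy labels: 2 positives, 3 negatives
parent = np.array(['yes', 'yes', 'no', 'no', 'no'])
# one hypothetical split: first three rows vs. last two
left, right = parent[:3], parent[3:]
weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(parent)
gain = entropy(parent) - weighted
print(round(entropy(parent), 3), round(weighted, 3), round(gain, 3))  # 0.971 0.551 0.42
```

Because the parent entropy is fixed, the split with the largest gain is exactly the split with the smallest weighted child entropy.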
Split Dataset
A classification algorithm must not only measure information entropy but also split the dataset effectively. Information entropy is the key metric for measuring dataset purity, while dataset splitting is the key step in building the decision tree.
Having seen how to compute information entropy, we can use information gain to evaluate the quality of a split. The specific steps:
- Calculate the total information entropy of the current dataset
- For each candidate feature:
  - Split the dataset into subsets according to the possible values of that feature
  - Calculate the information entropy of each subset
  - Calculate the weighted average information entropy after the split
  - Calculate the information gain of that feature (total entropy − weighted average entropy)
Best Split Function for Dataset
The criterion for splitting the dataset is to select the split with the maximum information gain, i.e., the direction in which uncertainty decreases fastest.
Manual calculation for the example dataset shows:
- Column 0's information gain is 0.42, while column 1's is 0.17
- So we should select column 0 to split the dataset
In Python, the information gain of each column can be printed with the following code:
```python
import numpy as np

# Define information entropy
def calEnt(dataSet):
    n = dataSet.shape[0]                                  # total number of rows in the dataset
    iset = dataSet.iloc[:, -1].value_counts()             # count of each label category
    p = iset / n                                          # proportion of each label category
    ent = (-p * np.log2(p)).sum()                         # information entropy
    return ent

# Select the optimal column to split on
def bestSplit(dataSet):
    baseEnt = calEnt(dataSet)                             # entropy of the original dataset
    bestGain = 0                                          # initialize the information gain
    axis = -1                                             # initialize the best split column
    for i in range(dataSet.shape[1] - 1):                 # loop over the feature columns
        levels = dataSet.iloc[:, i].value_counts().index  # all values of the current column
        ents = 0                                          # initialize the child-node entropy
        for j in levels:                                  # loop over the values of the current column
            childSet = dataSet[dataSet.iloc[:, i] == j]   # sub-DataFrame for one child node
            ent = calEnt(childSet)                        # entropy of that child node
            ents += (childSet.shape[0] / dataSet.shape[0]) * ent  # weighted entropy of the column
        print('Column {} information entropy is {}'.format(i, ents))
        infoGain = baseEnt - ents                         # information gain of the current column
        print('Column {} information gain is {}\n'.format(i, infoGain))
        if infoGain > bestGain:
            bestGain = infoGain                           # keep the maximum information gain
            axis = i                                      # index of the column with the maximum gain
    print("Column {} is the optimal split column".format(axis))
    return axis
```
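The dataset behind the 0.42 / 0.17 numbers above is not shown in this section. One small dataset consistent with those gains is the toy set below (the column names are assumptions). The snippet repeats calEnt/bestSplit in compact form, without the diagnostic prints, so it runs standalone:

```python
import numpy as np
import pandas as pd

def calEnt(dataSet):
    # entropy of the label column (last column)
    p = dataSet.iloc[:, -1].value_counts() / dataSet.shape[0]
    return float((-p * np.log2(p)).sum())

def bestSplit(dataSet):
    baseEnt, bestGain, axis = calEnt(dataSet), 0, -1
    for i in range(dataSet.shape[1] - 1):
        ents = 0
        for j in dataSet.iloc[:, i].unique():
            childSet = dataSet[dataSet.iloc[:, i] == j]
            ents += childSet.shape[0] / dataSet.shape[0] * calEnt(childSet)
        if baseEnt - ents > bestGain:
            bestGain, axis = baseEnt - ents, i
    return axis

# hypothetical toy dataset (column names assumed); last column is the label
df = pd.DataFrame({
    'no surfacing': [1, 1, 1, 0, 0],
    'flippers':     [1, 1, 0, 1, 1],
    'fish':         ['yes', 'yes', 'no', 'no', 'no'],
})
print(bestSplit(df))  # column 0 has the larger gain (≈ 0.42 vs ≈ 0.17)
```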
Split Dataset by Given Column
Using the column index returned by the best split function, we can build a function that splits the dataset on a given column.
"""
Function: Split dataset by given column
Parameters:
dataSet: Original dataset
axis: Specified column index
value: Specified attribute value
return: redataSet: Dataset after splitting by specified column index and attribute value
"""
def mySplit(dataSet, axis, value):
col = dataSet.columns[axis]
redataSet = dataSet.loc[dataSet[col] == value, :].drop(col, axis=1)
return redataSet
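A quick usage check on a hypothetical toy DataFrame (mySplit is repeated so the snippet runs standalone): splitting on column 0 with value 1 keeps the matching rows and drops the used column.

```python
import pandas as pd

def mySplit(dataSet, axis, value):
    col = dataSet.columns[axis]
    # keep rows matching the value, then drop the column that was split on
    return dataSet.loc[dataSet[col] == value, :].drop(col, axis=1)

# hypothetical toy dataset; last column is the label
df = pd.DataFrame({
    'no surfacing': [1, 1, 1, 0, 0],
    'flippers':     [1, 1, 0, 1, 1],
    'fish':         ['yes', 'yes', 'no', 'no', 'no'],
})
sub = mySplit(df, 0, 1)
print(sub.shape)  # (3, 2): three matching rows, 'no surfacing' column dropped
```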
Decision Tree Generation
We now have the sub-function modules needed to construct a decision tree from a dataset. The working principle:
- Take the original dataset and split it on the best attribute. Since a feature may take more than two values, a split may produce more than two branches.
- After the first split, each subset is passed down to the next node in the tree.
- At the new node the data can be split again, so the dataset can be processed recursively.
Recursion stops when:
- The program has used up all attributes available for splitting
- All instances under a branch share the same classification
- The current node's sample set is empty and cannot be split
In case 1, we mark the current node as a leaf and set its class to the majority class among its samples; any data reaching the leaf is then assigned that class. In case 2, the leaf simply takes the single class shared by its samples. In case 3, we also mark the current node as a leaf, but set its class to the majority class of its parent node.
ID3 Algorithm
The prototype of the ID3 algorithm was proposed by J.R. Quinlan and systematized in his 1986 paper "Induction of Decision Trees"; it is the first truly complete and systematic decision tree learning algorithm. ID3 uses the information gain criterion from information theory and a top-down greedy strategy to recursively build the decision tree. Its solid theoretical foundation and simple, intuitive implementation have made it widely used in machine learning.
The core idea of ID3: at each non-leaf node of the decision tree, select the attribute with the maximum information gain as the splitting criterion.
Concretely, ID3 applies the information gain criterion to select a feature at each node and builds the tree recursively:
- Starting from the root node, calculate the information gain of every candidate feature at the node.
- Select the feature with the maximum information gain as the node's feature, and create a child node for each of its values.
- Call the same procedure on each child node.
- Stop when every remaining feature's information gain is very small or no features remain, which yields the final decision tree.
Specific Python code implementation:
"""
Function: Split dataset based on maximum information gain, recursively build decision tree
Parameters:
dataSet: Original dataset (last column is label)
return: myTree: Tree in dictionary form
"""
def createTree(dataSet):
featlist = list(dataSet.columns) # Extract all columns of dataset
classlist = dataSet.iloc[:, -1].value_counts() # Get last column class labels
# Check if most label count equals dataset row count, or dataset has only one column
if classlist.iloc[0] == dataSet.shape[0] or dataSet.shape[1] == 1:
return classlist.index[0] # If yes, return class label
axis = bestSplit(dataSet) # Determine index of current best split column
bestfeat = featlist[axis] # Get feature corresponding to this index
myTree = {bestfeat: {}} # Store tree info using nested dictionary
del featlist[axis] # Delete current feature
valuelist = set(dataSet.iloc[:, axis]) # Extract all attribute values of best split column
# Recursively build tree for each attribute value
for value in valuelist:
myTree[bestfeat][value] = createTree(mySplit(dataSet, axis, value))
return myTree
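Putting the pieces together on the same hypothetical toy dataset yields a nested-dictionary tree. The helpers are repeated here in compact form (diagnostic prints removed) so the snippet runs standalone:

```python
import numpy as np
import pandas as pd

# Compact versions of calEnt / bestSplit / mySplit
def calEnt(dataSet):
    p = dataSet.iloc[:, -1].value_counts() / dataSet.shape[0]
    return float((-p * np.log2(p)).sum())

def bestSplit(dataSet):
    baseEnt, bestGain, axis = calEnt(dataSet), 0, -1
    for i in range(dataSet.shape[1] - 1):
        ents = 0
        for j in dataSet.iloc[:, i].unique():
            childSet = dataSet[dataSet.iloc[:, i] == j]
            ents += childSet.shape[0] / dataSet.shape[0] * calEnt(childSet)
        if baseEnt - ents > bestGain:
            bestGain, axis = baseEnt - ents, i
    return axis

def mySplit(dataSet, axis, value):
    col = dataSet.columns[axis]
    return dataSet.loc[dataSet[col] == value, :].drop(col, axis=1)

def createTree(dataSet):
    classlist = dataSet.iloc[:, -1].value_counts()
    if classlist.iloc[0] == dataSet.shape[0] or dataSet.shape[1] == 1:
        return classlist.index[0]
    axis = bestSplit(dataSet)
    bestfeat = dataSet.columns[axis]
    myTree = {bestfeat: {}}
    for value in set(dataSet.iloc[:, axis]):
        myTree[bestfeat][value] = createTree(mySplit(dataSet, axis, value))
    return myTree

# hypothetical toy dataset (column names assumed); last column is the label
df = pd.DataFrame({
    'no surfacing': [1, 1, 1, 0, 0],
    'flippers':     [1, 1, 0, 1, 1],
    'fish':         ['yes', 'yes', 'no', 'no', 'no'],
})
tree = createTree(df)
print(tree)
```

On this data the root splits on column 0 and the remaining impure branch splits on column 1, giving a two-level tree.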
ID3 Algorithm Limitations
The limitations of ID3 stem mainly from its local optimality criterion, i.e., the way information gain is computed. Main limitations:
- Features with more branches (more distinct levels) tend to have a smaller total child-node entropy, so ID3 is biased toward them even when the split says little about the outcome. In the extreme case of splitting on an ID field, every child is 100% pure, yet the split has no predictive benefit.
- Continuous variables cannot be handled directly; they must be discretized first.
- ID3 is relatively sensitive to missing values, which must be handled beforehand.
- There is no pruning, which easily leads to overfitting: good performance on the training set, poor performance on the test set.
Many of the optimizations addressing these points form the core of the C4.5 algorithm.
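As an illustration of the continuous-variable limitation, binary threshold splitting (the strategy C4.5 adopts for numeric features) can be sketched as follows: sort the values, enumerate midpoints between adjacent distinct values, and keep the threshold with the largest gain. This is a sketch with assumed toy data, not the exact C4.5 implementation:

```python
import numpy as np

def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

def best_threshold(x, y):
    # Enumerate midpoints between adjacent distinct values and keep
    # the threshold with the largest information gain.
    x, y = np.asarray(x), np.asarray(y)
    base = entropy(y)
    vals = np.unique(x)                   # sorted distinct values
    best_gain, best_t = 0.0, None
    for t in (vals[:-1] + vals[1:]) / 2:  # candidate thresholds
        left, right = y[x <= t], y[x > t]
        w = len(left) / len(y)
        gain = base - (w * entropy(left) + (1 - w) * entropy(right))
        if gain > best_gain:
            best_gain, best_t = gain, t
    return best_t, best_gain

# hypothetical toy data: the two classes separate around 6.5
t, g = best_threshold([1, 2, 3, 10, 11], ['no', 'no', 'no', 'yes', 'yes'])
print(t, round(g, 3))  # 6.5 splits the two classes perfectly
```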
C4.5 Algorithm
C4.5 is similar to ID3 and improves on it: during tree generation it selects features using the information gain ratio instead of the raw information gain.
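A hedged sketch of the gain ratio criterion (helper names are assumptions): the gain is divided by the feature's own entropy (Quinlan's "split information"), which penalizes high-cardinality features such as an ID column.

```python
import numpy as np
import pandas as pd

def entropy(s):
    p = s.value_counts() / len(s)
    return float((-p * np.log2(p)).sum())

def gain_ratio(dataSet, i):
    # information gain of column i (last column is the label)
    base = entropy(dataSet.iloc[:, -1])
    ents = 0
    for v in dataSet.iloc[:, i].unique():
        child = dataSet[dataSet.iloc[:, i] == v]
        ents += len(child) / len(dataSet) * entropy(child.iloc[:, -1])
    gain = base - ents
    # intrinsic value ("split information"): entropy of the feature itself
    iv = entropy(dataSet.iloc[:, i])
    return gain / iv if iv > 0 else 0.0

# hypothetical toy dataset; 'id' is unique per row
df = pd.DataFrame({
    'no surfacing': [1, 1, 1, 0, 0],
    'id':           [0, 1, 2, 3, 4],
    'fish':         ['yes', 'yes', 'no', 'no', 'no'],
})
# raw gain would prefer 'id' (gain ≈ 0.97 vs ≈ 0.42), but the gain ratio
# penalizes its large split information and prefers 'no surfacing'
print(round(gain_ratio(df, 0), 3), round(gain_ratio(df, 1), 3))
```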
Error Quick Reference
| Symptom | Root Cause | Diagnosis | Fix |
|---|---|---|---|
| Entropy calculation gets nan or inf | log2(0) when probability is 0 | Check if p appears 0 in calEnt() | Filter 0 probability: p = p[p>0] before calculating; or add tiny epsilon |
| Information gain all 0 or very small | Label column selected wrong, data incorrectly filtered causing unchanged entropy | Check if dataSet.iloc[:, -1] truly is label; check if subsets empty | Clarify label column; check childSet.shape[0]; skip empty subsets |
| bestSplit output column unstable | set() traversal unordered, value_counts().index order affects print appearance | Source of levels, valuelist | Use sorted(levels) for determinism; also use sorting for tree branches |
| drop(…, axis=1) reports parameter error/warning | Pandas version difference or parameter habit inconsistency | Fixed writing at drop(col, axis=1) | drop(columns=[col]) |
| Recursive tree building results abnormally deep/overfitting | No pruning, many discrete values causing branch explosion | Branch count per layer in createTree() | Introduce min sample count/max depth/info gain threshold; or switch to C4.5/pruning strategy |
| Taking “ID column/unique value column” as best feature | Information gain prefers features with many values | Observe column value count much larger than others | Use gain ratio (C4.5) or pre-filter high-cardinality identifier columns |
| Continuous value features cannot be handled | Code splits by discrete values levels = value_counts().index directly enumerates | Continuous features | First bin/discretize; or implement threshold splitting (sort + enumerate split points) |