A decision tree is a flowchart-like supervised classification model. It represents the solutions available for a given problem in terms of the possible outcomes of a series of related choices.
Decision trees:
- Require no assumptions regarding the distribution of the data
- Handle collinearity very easily
- Often don’t require data preprocessing
However, decision trees are not perfect: they can be particularly susceptible to overfitting.
Decisions are made at each node. At each node, a single feature of the data is considered and a decision is made on it. By the end, any relevant features will have been resolved, resulting in the classification prediction.
Note: All examples in this post depict nodes that split into just two child nodes, because this is how they are implemented in modeling libraries, including scikit-learn. While it’s theoretically possible to split a node into more than two groups, this is an impractical approach because of its computational complexity. A two-level binary split is functionally equivalent to a single-level three-way split, but much simpler in terms of computational demand.
Decisions and splits
In a decision tree, the data is split and passed down through decision nodes until it reaches a leaf node. A decision node is split on the criterion that minimizes the impurity of the classes in its resulting children. Impurity refers to the degree of mixture with respect to class.
Nodes with low impurity have many more of one class than any other. A perfect split would have no impurity in the resulting child nodes; it would partition the data with each child containing only a single class. The worst possible split would have high impurity in the resulting child nodes; both of the child nodes would have equal numbers of each class.
When building a tree and growing a new node, a set of potential split points is generated for every predictor variable in the dataset. An algorithm is used to calculate the “purity” of the child nodes that would result from each split point of each variable. The feature and split point that generate the purest child nodes are selected to partition the data.
To determine the set of potential split points that will be considered for a variable, the algorithm first identifies what type of variable it is (categorical or continuous) and the range of values that exist for that variable.
A sample to grow a tree
Let’s grow a tree for the following data.
Color | Diameter (cm) | Fruit (target) |
Yellow | 3.5 | Apple |
Yellow | 7 | Apple |
Red | 2 | Grape |
Red | 2.5 | Grape |
Green | 4 | Grape |
Green | 3 | Apple |
Red | 6 | Apple |
First, the algorithm will consider splitting based on the categorical variable, Color. Since there are three categories, three options are considered: Color = red (option A), Color = yellow (option B), and Color = green (option C). Note that “yes” always goes to the left and “no” to the right.
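To make the three Color options concrete, here is a minimal Python sketch that enumerates one yes/no split per category. The `rows` list simply re-types the table above, and `categorical_splits` is a hypothetical helper name used for illustration, not a library function.

```python
# Fruit data from the table above: (color, diameter_cm, fruit)
rows = [
    ("Yellow", 3.5, "Apple"),
    ("Yellow", 7.0, "Apple"),
    ("Red", 2.0, "Grape"),
    ("Red", 2.5, "Grape"),
    ("Green", 4.0, "Grape"),
    ("Green", 3.0, "Apple"),
    ("Red", 6.0, "Apple"),
]


def categorical_splits(rows, column):
    """Return one yes/no split per category: {category: (yes_labels, no_labels)}."""
    categories = sorted({row[column] for row in rows})
    splits = {}
    for category in categories:
        yes = [row[-1] for row in rows if row[column] == category]  # left child
        no = [row[-1] for row in rows if row[column] != category]   # right child
        splits[category] = (yes, no)
    return splits


color_splits = categorical_splits(rows, column=0)
for category, (yes, no) in color_splits.items():
    print(f"Color = {category}?  yes -> {yes}  no -> {no}")
```

Each category produces exactly one candidate split, so a three-category variable contributes three options.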
If the predictor variable is continuous, splits can be made anywhere along the range of numbers that exist in the data. Often the potential split points are determined by sorting the values for the feature and taking the mean of each consecutive pair of values.
However, there can be any number of split points, and fewer split points can be considered to save computational resources and time. It is very common, especially when dealing with very large ranges of numbers, to consider split points along percentiles of the distribution.
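As a rough sketch of the percentile idea, the snippet below compares the number of midpoint candidates with a fixed set of decile candidates. The feature values are randomly generated for illustration (they are not from the fruit data), and using deciles is just one assumed choice of percentiles.

```python
import random
import statistics

# Hypothetical continuous feature with many distinct values (not from the post).
random.seed(0)
values = [random.uniform(0, 1000) for _ in range(10_000)]

# Midpoints between every consecutive pair of distinct values: one candidate per gap.
distinct = sorted(set(values))
all_midpoints = [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]

# Cheaper alternative: the 9 decile boundaries (10th, 20th, ..., 90th percentiles).
decile_candidates = statistics.quantiles(values, n=10, method="inclusive")

print(len(all_midpoints), "midpoint candidates")
print(len(decile_candidates), "percentile candidates")
```

With percentiles, the number of candidates stays fixed no matter how many distinct values the feature has, which is where the savings come from.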
In the fruit example above, Diameter is a continuous variable. One way a decision tree could handle this is to sort the values and take the average of each consecutive pair:
Sorted diameters (cm): 2, 2.5, 3, 3.5, 4, 6, 7
Averages of consecutive pairs: 2.25, 2.75, 3.25, 3.75, 5, 6.5
Then we examine splitting based on each of these means.
These are the six potential split points for the Diameter feature that were identified by the algorithm. Each option includes the children of that split, but we have not evaluated them yet.
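The sort-and-average step can be sketched in a few lines of Python. The variable names are my own; the values come straight from the Diameter column of the table.

```python
# Diameter values (cm) from the fruit table above.
diameters = [3.5, 7, 2, 2.5, 4, 3, 6]

# Sort the distinct values, then average each consecutive pair.
ordered = sorted(set(diameters))
split_points = [(low + high) / 2 for low, high in zip(ordered, ordered[1:])]

print(ordered)       # [2, 2.5, 3, 3.5, 4, 6, 7]
print(split_points)  # [2.25, 2.75, 3.25, 3.75, 5.0, 6.5]
```

Seven distinct values leave six gaps, which is why there are six candidate split points.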
Choosing splits: Gini impurity
Generally, splits are better when each resulting child node contains many more samples of one class than any other, like in example E above, because this means the split is effectively separating the classes, the primary job of the decision tree. In such cases, the child nodes are said to have low impurity (or high purity).
The decision tree algorithm determines the split that will result in the lowest impurity among the child nodes by performing a calculation. There are several possible metrics to use to determine the purity of a node and to decide how to split, including Gini impurity, entropy, information gain, and log loss.
The most straightforward is Gini impurity, and it’s also the default for the decision tree classifier in scikit-learn, so we will focus on that method here. The Gini impurity of a node is defined as:
$$\text{Gini impurity} = 1 - \sum_{i} P(i)^2$$
where i = class,
P(i) = the probability of samples belonging to class i in a given node.
In the case of the fruits example, this becomes:
$$\text{Gini impurity} = 1 - P(\text{apple})^2 - P(\text{grape})^2$$
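Here is a minimal Python sketch of that definition. The function name `gini_impurity` and the choice to return 0 for an empty node are my own assumptions for illustration.

```python
from collections import Counter


def gini_impurity(labels):
    """Gini impurity of one node: 1 minus the sum of squared class proportions."""
    total = len(labels)
    if total == 0:
        return 0.0  # convention assumed here for an empty node
    counts = Counter(labels)
    return 1.0 - sum((count / total) ** 2 for count in counts.values())


# All seven fruits before any split: 4 apples and 3 grapes.
root = ["Apple"] * 4 + ["Grape"] * 3
print(round(gini_impurity(root), 3))      # 0.49
print(gini_impurity(["Grape", "Grape"]))  # 0.0 (pure node)
print(gini_impurity(["Apple", "Grape"]))  # 0.5 (evenly mixed, two classes)
```

A pure node scores 0, and an evenly mixed two-class node scores 0.5, so lower is better.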
The Gini impurity is calculated for each child node of each potential split point. For example, there are nine split point options in the fruit example (A–I). The first potential split point is Color=red:
Calculate Gini impurity of each child node
For the “red=yes” child node (1 apple, 2 grapes):
$$\text{Gini impurity} = 1 - (1/3)^2 - (2/3)^2 = 1 - 1/9 - 4/9 = 4/9 \approx 0.444$$
And for the “red=no” child node (3 apples, 1 grape):
$$\text{Gini impurity} = 1 - (3/4)^2 - (1/4)^2 = 1 - 0.5625 - 0.0625 = 0.375$$
Now there are two Gini impurity scores for split option A (whether or not the fruit is red), one for each child node. The final step is to combine these scores by taking their weighted average.
Calculate weighted average of Gini impurities
The weighted average accounts for the different number of samples represented in each Gini impurity score: each child’s score is weighted by its share of the samples, here 3 of 7 and 4 of 7. The weighted average of the Gini impurities (Gi) is calculated as:
$$Gi_{\text{total}} = (3/7 \times 0.444) + (4/7 \times 0.375) = 0.405$$
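The same arithmetic for split option A can be checked with a short Python sketch. The helper `gini` takes per-class counts rather than labels; that is an illustrative choice of mine, not something prescribed by the source.

```python
def gini(counts):
    """Gini impurity from a list of per-class sample counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


# Split option A (Color = red): class counts are [apples, grapes].
red_yes = [1, 2]  # left child: 1 apple, 2 grapes
red_no = [3, 1]   # right child: 3 apples, 1 grape

n_yes, n_no = sum(red_yes), sum(red_no)
n_total = n_yes + n_no

# Weight each child's impurity by its share of the samples.
weighted = (n_yes / n_total) * gini(red_yes) + (n_no / n_total) * gini(red_no)

print(round(gini(red_yes), 3))  # 0.444
print(round(gini(red_no), 3))   # 0.375
print(round(weighted, 3))       # 0.405
```

Because the sketch works from exact fractions rather than rounded intermediate terms, the red=yes impurity prints as 0.444 (4/9), and the weighted average still rounds to 0.405.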
Repeat this process for every split option
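This same process is repeated for every split option. The fruit example has nine options (A–I). Recomputing the weighted Gini impurity for each from the data above, and assuming the Diameter options are labeled D–I in ascending order with “yes” meaning the diameter is at or below the split point, gives:
Option | Split | Weighted Gini impurity |
A | Color = red | 0.405 |
B | Color = yellow | 0.343 |
C | Color = green | 0.486 |
D | Diameter ≤ 2.25 | 0.381 |
E | Diameter ≤ 2.75 | 0.229 |
F | Diameter ≤ 3.25 | 0.405 |
G | Diameter ≤ 3.75 | 0.476 |
H | Diameter ≤ 5 | 0.343 |
I | Diameter ≤ 6.5 | 0.429 |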
Now there are nine Gini impurity scores ranging from 0.229 to 0.486. Since it’s a measure of impurity, the best scores are those closest to zero. In this case, that’s option E. The worst of the nine options is option C, because it doesn’t separate classes well. For a two-class problem like this one, the worst possible Gini impurity score is 0.5, which would occur when each child node contains an equal number of each class.
Now that the algorithm has identified the potential split points and calculated the Gini impurity of the child nodes that result from them, it will grow the tree by selecting the split point with the lowest Gini impurity.
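Putting the pieces together, the sketch below scores all nine candidate splits for the fruit data and picks the lowest. It assumes that the “yes” branch of a Diameter split means “diameter at or below the split point”; the function names are hypothetical, and this is not how scikit-learn is implemented internally.

```python
from collections import Counter

# Fruit data from the table above: (color, diameter_cm, fruit)
rows = [
    ("Yellow", 3.5, "Apple"), ("Yellow", 7.0, "Apple"),
    ("Red", 2.0, "Grape"), ("Red", 2.5, "Grape"),
    ("Green", 4.0, "Grape"), ("Green", 3.0, "Apple"),
    ("Red", 6.0, "Apple"),
]


def gini(labels):
    """Gini impurity of a single node (0.0 for an empty node, by assumption)."""
    total = len(labels)
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in Counter(labels).values())


def weighted_gini(yes, no):
    """Average of the two child impurities, weighted by child size."""
    n = len(yes) + len(no)
    return len(yes) / n * gini(yes) + len(no) / n * gini(no)


def candidate_splits(rows):
    """Yield (description, yes_labels, no_labels) for every split option."""
    for color in ("Red", "Yellow", "Green"):
        yes = [r[2] for r in rows if r[0] == color]
        no = [r[2] for r in rows if r[0] != color]
        yield f"Color = {color}", yes, no
    diameters = sorted({r[1] for r in rows})
    for low, high in zip(diameters, diameters[1:]):
        threshold = (low + high) / 2
        yes = [r[2] for r in rows if r[1] <= threshold]  # assumed "yes" direction
        no = [r[2] for r in rows if r[1] > threshold]
        yield f"Diameter <= {threshold}", yes, no


scores = {name: weighted_gini(yes, no) for name, yes, no in candidate_splits(rows)}
for name, score in scores.items():
    print(f"{name:<18} {score:.3f}")

best = min(scores, key=scores.get)
print("Best split:", best)  # Diameter <= 2.75
```

The lowest score (about 0.229) belongs to the split at 2.75 cm, and the highest (about 0.486) to Color = Green, which matches the range quoted above.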
Grow the tree
In the case of the given example, the root node would use split option E to split the data. The left child becomes a leaf node, because it contains just one class. However, the right child still does not have class purity, so it becomes a new decision node (in the absence of some imposed stopping condition). The steps outlined above will repeat on the samples in this node to identify the feature and split value that would yield the best result.
Splitting would continue until all the leaves are pure or some imposed condition stops the splitting.
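The repeat-until-pure idea can be sketched as a small recursive function. This is a toy illustration under my own assumptions (nested dicts as the tree structure, `<=` for Diameter thresholds, first-found wins on ties, no stopping condition other than purity), not the algorithm as scikit-learn implements it.

```python
from collections import Counter

# Fruit data from the table above: (color, diameter_cm, fruit)
rows = [
    ("Yellow", 3.5, "Apple"), ("Yellow", 7.0, "Apple"),
    ("Red", 2.0, "Grape"), ("Red", 2.5, "Grape"),
    ("Green", 4.0, "Grape"), ("Green", 3.0, "Apple"),
    ("Red", 6.0, "Apple"),
]


def gini(labels):
    total = len(labels)
    return 1.0 - sum((c / total) ** 2 for c in Counter(labels).values())


def matches(row, column, value):
    """True sends the row to the left ("yes") child."""
    if column == 0:         # Color: categorical equality test
        return row[0] == value
    return row[1] <= value  # Diameter: threshold test (assumed <=)


def best_split(rows):
    """Return (score, column, value, left_rows, right_rows) with the lowest weighted Gini."""
    best = None
    for column in (0, 1):
        values = sorted({row[column] for row in rows})
        if column == 0:
            candidates = values  # one yes/no test per category
        else:
            candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
        for value in candidates:
            left = [row for row in rows if matches(row, column, value)]
            right = [row for row in rows if not matches(row, column, value)]
            if not left or not right:
                continue  # a split must send samples to both children
            score = (len(left) * gini([r[2] for r in left])
                     + len(right) * gini([r[2] for r in right])) / len(rows)
            if best is None or score < best[0]:
                best = (score, column, value, left, right)
    return best


def grow(rows):
    """Recursively split until every leaf is pure (or no split is possible)."""
    labels = [row[2] for row in rows]
    split = None if gini(labels) == 0 else best_split(rows)
    if split is None:
        return Counter(labels).most_common(1)[0][0]  # leaf: majority class
    _, column, value, left, right = split
    return {"column": column, "value": value, "yes": grow(left), "no": grow(right)}


def predict(tree, row):
    """Follow yes/no branches until a leaf (a class label) is reached."""
    while isinstance(tree, dict):
        tree = tree["yes"] if matches(row, tree["column"], tree["value"]) else tree["no"]
    return tree


tree = grow(rows)
print(tree)
```

For the fruit data, the root splits Diameter at 2.75 and its “yes” child is a pure Grape leaf, while the “no” child keeps splitting until its leaves are pure as well.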
This process involves a lot of computation, and this example used a dataset with only two features and seven observations. As with many machine learning algorithms, the theory and methodology behind decision trees are fairly straightforward and have been around for many years, but it wasn’t until the advent of powerful computing capabilities that these solutions could be put into practice.
Advantages and disadvantages of classification trees
Advantages:
- Require relatively few pre-processing steps
- Can work easily with all types of variables (continuous, categorical, discrete)
- Do not require normalization or scaling
- Decisions are transparent
- Not affected by extreme univariate values
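To illustrate the scaling and extreme-value points, here is a small sketch using the Diameter column from the fruit example. It finds the lowest-Gini threshold split three times: on the original values, on the values rescaled from cm to mm, and with the largest value replaced by an extreme one. The helper name and the 700 cm outlier are made up for illustration.

```python
from collections import Counter


def gini(labels):
    total = len(labels)
    return 1.0 - sum((c / total) ** 2 for c in Counter(labels).values())


def best_left_group(values, labels):
    """Indices of the samples sent left by the threshold with the lowest weighted Gini."""
    ordered = sorted(set(values))
    best_score, best_left = None, None
    for low, high in zip(ordered, ordered[1:]):
        threshold = (low + high) / 2
        left = [i for i, v in enumerate(values) if v <= threshold]
        right = [i for i, v in enumerate(values) if v > threshold]
        score = (len(left) * gini([labels[i] for i in left])
                 + len(right) * gini([labels[i] for i in right])) / len(values)
        if best_score is None or score < best_score:
            best_score, best_left = score, left
    return best_left


diameters_cm = [3.5, 7, 2, 2.5, 4, 3, 6]
fruits = ["Apple", "Apple", "Grape", "Grape", "Grape", "Apple", "Apple"]

diameters_mm = [d * 10 for d in diameters_cm]                 # rescaled feature
with_outlier = [700 if d == 7 else d for d in diameters_cm]   # one extreme value

print(best_left_group(diameters_cm, fruits))   # [2, 3]
print(best_left_group(diameters_mm, fruits))   # [2, 3]
print(best_left_group(with_outlier, fruits))   # [2, 3]
```

The same two samples go left in all three cases, because a threshold split depends only on the ordering of the values, not on their scale.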
Disadvantages:
- Can be computationally expensive relative to other algorithms
- Small changes in data can result in significant changes in predictions
Disclaimer: Like most of my posts, this content is intended solely for educational purposes and was created primarily for my personal reference. At times, I may rephrase original texts, and in some cases, I include materials such as graphs, equations, and datasets directly from their original sources.
I typically reference a variety of sources and update my posts whenever new or related information becomes available. For this particular post, the primary source was the Google Advanced Data Analytics Professional Certificate program.