
Decision trees

In this series (18 parts)
  1. What is machine learning: a map of the field
  2. Data, features, and the ML pipeline
  3. Linear regression
  4. Bias, variance, and the tradeoff
  5. Regularization: Ridge, Lasso, and ElasticNet
  6. Logistic regression and classification
  7. Evaluation metrics for classification
  8. Naive Bayes classifier
  9. K-Nearest Neighbors
  10. Decision trees
  11. Ensemble methods: Bagging and Random Forests
  12. Boosting: AdaBoost and Gradient Boosting
  13. Support Vector Machines
  14. K-Means clustering
  15. Dimensionality Reduction: PCA
  16. Gaussian mixture models and EM algorithm
  17. Model selection and cross-validation
  18. Feature engineering and selection

A decision tree splits your data into smaller and smaller groups by asking yes/no questions about the features. Each question divides the data at an internal node. When you reach a leaf, you get a prediction. That’s it. No matrix algebra, no gradient descent, no loss surface to navigate. Just a sequence of if-else rules learned from the data.

The simplicity is real, but so is the power. Decision trees are the building blocks of random forests and gradient-boosted trees, two of the most effective algorithms in applied ML. Understanding how a single tree works gives you the foundation for everything that follows.

Loan approval: a decision in plain language

A bank reviews loan applications. Here are 8 recent applicants.

| Applicant | Income | Credit score | Employed | Approved |
|---|---|---|---|---|
| 1 | 45k | 620 | Yes | No |
| 2 | 80k | 720 | Yes | Yes |
| 3 | 30k | 580 | No | No |
| 4 | 65k | 680 | Yes | Yes |
| 5 | 50k | 700 | Yes | Yes |
| 6 | 35k | 600 | No | No |
| 7 | 90k | 750 | Yes | Yes |
| 8 | 40k | 650 | Yes | No |

A decision tree asks yes/no questions to split these applicants into pure groups. It might first ask: “Is the credit score above 660?” Applicants above that threshold mostly get approved. Those below mostly get rejected. Then it refines each group with more questions.

A simple loan approval tree

graph TD
  A["Credit score > 660?"] -->|Yes| B["Income > 50k?"]
  A -->|No| C["Rejected"]
  B -->|Yes| D["Approved"]
  B -->|No| E["Borderline: check employment"]

The tree learns these questions automatically from data. It tries every possible feature and every possible split point, then picks the one that separates the classes most cleanly. The measure of “cleanly” is called impurity, and reducing impurity is the entire optimization objective.

Now let’s define impurity formally and see exactly how the algorithm chooses splits.

Recursive partitioning

Building a decision tree is a recursive process:

  1. Look at all your data at the current node.
  2. Try every feature and every possible split point.
  3. Pick the split that best separates the target classes (or best reduces prediction error for regression).
  4. Create two child nodes, one for each side of the split.
  5. Repeat on each child until a stopping condition is met.

The result is a binary tree where internal nodes are questions (“Is age > 30?”), branches are answers (yes/no), and leaves are predictions.

Recursive splitting: how a tree grows

graph TD
  A["Start: all data at root"] --> B["Find best feature and threshold"]
  B --> C["Split into left and right child"]
  C --> D["Left child: repeat"]
  C --> E["Right child: repeat"]
  D --> F["Stop if pure or max depth"]
  E --> F

Impurity measures

To pick the “best” split, you need a way to measure how mixed a node is. A node that contains only one class is pure. A node with a 50/50 mix is maximally impure. Two common measures handle this: Gini impurity and entropy.

Gini impurity

For a node with K classes, Gini impurity is:

G = 1 - \sum_{k=1}^{K} p_k^2

where p_k is the fraction of samples belonging to class k.

If a node has 6 positive and 4 negative samples:

G = 1 - \left(\frac{6}{10}\right)^2 - \left(\frac{4}{10}\right)^2 = 1 - 0.36 - 0.16 = 0.48

A pure node has G = 0. The worst case for binary classification is G = 0.5 (perfectly balanced classes).
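The calculation above is easy to sanity-check in a few lines of Python (a throwaway helper written for this post, not part of any library):

```python
def gini(counts):
    """Gini impurity from a list of per-class sample counts."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)

print(round(gini([6, 4]), 2))  # 0.48, matching the hand calculation
print(gini([10, 0]))           # 0.0, a pure node
print(gini([5, 5]))            # 0.5, the binary worst case
```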

Entropy and information gain

Entropy comes from information theory and measures the average surprise in a distribution:

H = -\sum_{k=1}^{K} p_k \log_2(p_k)

For the same 6/4 split:

H = -\frac{6}{10}\log_2\left(\frac{6}{10}\right) - \frac{4}{10}\log_2\left(\frac{4}{10}\right) = -0.6 \times (-0.737) - 0.4 \times (-1.322) = 0.442 + 0.529 = 0.971

Information gain is the reduction in entropy after a split:

\text{IG} = H(\text{parent}) - \sum_{i} \frac{n_i}{n} H(\text{child}_i)

where n_i is the number of samples in child i and n is the total.
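Entropy and information gain can be checked the same way (again with small ad-hoc helpers, not library code):

```python
import math

def entropy(counts):
    """Shannon entropy in bits from per-class sample counts."""
    n = sum(counts)
    return -sum((c / n) * math.log2(c / n) for c in counts if c > 0)

def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of the children."""
    n = sum(parent)
    return entropy(parent) - sum(sum(ch) / n * entropy(ch) for ch in children)

print(round(entropy([6, 4]), 3))                   # 0.971, as computed above
print(information_gain([5, 5], [[5, 0], [0, 5]]))  # 1.0 for a perfect split
```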

Gini vs entropy

In practice, Gini and entropy almost always pick the same split. Gini is slightly faster to compute because it avoids the logarithm. Scikit-learn uses Gini by default. Entropy can be more sensitive to class balance in edge cases. Pick either one; the difference is rarely meaningful.

Gini vs Entropy: two ways to measure node impurity

graph TD
  subgraph Gini["Gini Impurity"]
      G1["Pure node: G = 0"]
      G2["50/50 split: G = 0.5"]
      G3["No logarithm needed"]
  end
  subgraph Entropy["Entropy"]
      E1["Pure node: H = 0"]
      E2["50/50 split: H = 1.0"]
      E3["Uses log base 2"]
  end
  N["Both pick the same split in most cases"]

The splitting algorithm

Here is the procedure for finding the best split at a node:

import numpy as np

def weighted_impurity(y_left, y_right):
    """Size-weighted Gini impurity of the two child nodes."""
    def gini(y):
        _, counts = np.unique(y, return_counts=True)
        p = counts / counts.sum()
        return 1.0 - np.sum(p ** 2)
    n = len(y_left) + len(y_right)
    return len(y_left) / n * gini(y_left) + len(y_right) / n * gini(y_right)

def find_best_split(X, y):
    """Exhaustively search every feature and threshold for the cleanest split."""
    best_score = float('inf')
    best_feature = None
    best_threshold = None

    for feature in range(X.shape[1]):
        # Candidate thresholds sit midway between consecutive unique values
        thresholds = sorted(set(X[:, feature]))
        for i in range(len(thresholds) - 1):
            t = (thresholds[i] + thresholds[i + 1]) / 2
            left_mask = X[:, feature] <= t
            right_mask = ~left_mask

            score = weighted_impurity(y[left_mask], y[right_mask])
            if score < best_score:
                best_score = score
                best_feature = feature
                best_threshold = t

    return best_feature, best_threshold

For each feature, you sort its unique values and try thresholds between consecutive values. For each threshold, you compute the weighted impurity of the resulting left and right children. The split with the lowest weighted impurity wins.

The weighted impurity for a split is:

\text{Impurity}_{\text{split}} = \frac{n_L}{n} G_L + \frac{n_R}{n} G_R

where n_L and n_R are the sizes of the left and right children.

Worked example 1: building a tree by hand with Gini impurity

Consider this small dataset. We want to predict whether someone buys a product based on Age and Income.

| Row | Age | Income | Buys |
|---|---|---|---|
| 1 | 25 | 40k | No |
| 2 | 30 | 50k | No |
| 3 | 35 | 60k | Yes |
| 4 | 40 | 55k | Yes |
| 5 | 45 | 70k | Yes |
| 6 | 50 | 45k | No |
| 7 | 28 | 65k | Yes |
| 8 | 33 | 42k | No |

We have 4 Yes and 4 No, so the root Gini is:

G_{\text{root}} = 1 - (0.5)^2 - (0.5)^2 = 0.5

Try splitting on Income with threshold 52.5k (between 50k and 55k):

Left (≤ 52.5k): rows 1, 2, 6, 8 with labels [No, No, No, No]

G_L = 1 - (0/4)^2 - (4/4)^2 = 1 - 0 - 1 = 0

Right (> 52.5k): rows 3, 4, 5, 7 with labels [Yes, Yes, Yes, Yes]

G_R = 1 - (4/4)^2 - (0/4)^2 = 1 - 1 - 0 = 0

Weighted impurity:

\frac{4}{8}(0) + \frac{4}{8}(0) = 0

This is a perfect split. But let’s also check another split for comparison.

Try splitting on Age with threshold 32.5 (between 30 and 33):

Left (≤ 32.5): rows 1, 2, 7 with labels [No, No, Yes]

G_L = 1 - (1/3)^2 - (2/3)^2 = 1 - 0.111 - 0.444 = 0.444

Right (> 32.5): rows 3, 4, 5, 6, 8 with labels [Yes, Yes, Yes, No, No]

G_R = 1 - (3/5)^2 - (2/5)^2 = 1 - 0.36 - 0.16 = 0.48

Weighted impurity:

\frac{3}{8}(0.444) + \frac{5}{8}(0.48) = 0.167 + 0.3 = 0.467

The Income split (0.0) beats the Age split (0.467), so Income ≤ 52.5k is our root split.
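Both weighted-impurity numbers can be reproduced with a short script (the data is the eight-row table above; `gini` and `weighted_gini` are ad-hoc helpers, not library functions):

```python
def gini(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def weighted_gini(labels, values, threshold):
    """Size-weighted Gini impurity of the two sides of a threshold split."""
    left = [l for l, v in zip(labels, values) if v <= threshold]
    right = [l for l, v in zip(labels, values) if v > threshold]
    n = len(labels)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

buys   = ["No", "No", "Yes", "Yes", "Yes", "No", "Yes", "No"]
age    = [25, 30, 35, 40, 45, 50, 28, 33]
income = [40, 50, 60, 55, 70, 45, 65, 42]  # in thousands

print(weighted_gini(buys, income, 52.5))         # 0.0, a perfect split
print(round(weighted_gini(buys, age, 32.5), 3))  # 0.467
```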

Since both children are already pure, no further splitting is needed. Here is the resulting tree:

graph TD
  A["Income ≤ 52.5k?"]
  A -->|Yes| B["🔴 No<br/>4/4 No"]
  A -->|No| C["🟢 Yes<br/>4/4 Yes"]
  style B fill:#ffcccc,stroke:#cc0000,color:#000
  style C fill:#ccffcc,stroke:#00cc00,color:#000

(Figure: decision boundary regions for a 2D classification problem)

That was a clean example. Real data is messier, and you rarely get a single perfect split. In those cases, you keep splitting each child node recursively.

Worked example 2: information gain calculation

Let’s use a slightly different dataset to compare splits using entropy.

| Row | Feature A | Feature B | Label |
|---|---|---|---|
| 1 | 1 | 0 | + |
| 2 | 1 | 1 | + |
| 3 | 0 | 1 | + |
| 4 | 0 | 0 | - |
| 5 | 1 | 0 | - |
| 6 | 0 | 0 | - |

The root has 3 positive and 3 negative samples.

Root entropy:

H_{\text{root}} = -\frac{3}{6}\log_2\frac{3}{6} - \frac{3}{6}\log_2\frac{3}{6} = -0.5 \times (-1) - 0.5 \times (-1) = 1.0

Maximum entropy for binary classification. Totally mixed.

Split on Feature A (A = 1 vs A = 0):

Left (A = 1): rows 1, 2, 5 with labels [+, +, -]

H_L = -\frac{2}{3}\log_2\frac{2}{3} - \frac{1}{3}\log_2\frac{1}{3} = -0.667 \times (-0.585) - 0.333 \times (-1.585) = 0.390 + 0.528 = 0.918

Right (A = 0): rows 3, 4, 6 with labels [+, -, -]

H_R = -\frac{1}{3}\log_2\frac{1}{3} - \frac{2}{3}\log_2\frac{2}{3} = 0.528 + 0.390 = 0.918

Information gain for Feature A:

\text{IG}(A) = 1.0 - \frac{3}{6}(0.918) - \frac{3}{6}(0.918) = 1.0 - 0.918 = 0.082

Split on Feature B (B = 1 vs B = 0):

Left (B = 1): rows 2, 3 with labels [+, +]

H_L = -\frac{2}{2}\log_2\frac{2}{2} - 0 = 0

Right (B = 0): rows 1, 4, 5, 6 with labels [+, -, -, -]

H_R = -\frac{1}{4}\log_2\frac{1}{4} - \frac{3}{4}\log_2\frac{3}{4} = -0.25 \times (-2) - 0.75 \times (-0.415) = 0.5 + 0.311 = 0.811

Information gain for Feature B:

\text{IG}(B) = 1.0 - \frac{2}{6}(0) - \frac{4}{6}(0.811) = 1.0 - 0 - 0.541 = 0.459

Feature B wins with information gain of 0.459 versus 0.082 for Feature A. Feature B produces a pure left child (both positive), which drops the overall entropy substantially. Feature A barely separates the classes at all.

This is the core idea: the best split is the one that gives you the most information about the target variable.
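To double-check the two gains, here is the whole example as a script (`entropy` and `information_gain` are throwaway helpers written for this post):

```python
import math

def entropy(labels):
    n = len(labels)
    return -sum((labels.count(c) / n) * math.log2(labels.count(c) / n)
                for c in set(labels))

def information_gain(labels, feature):
    """Entropy reduction from splitting on a binary 0/1 feature."""
    left = [l for l, f in zip(labels, feature) if f == 1]
    right = [l for l, f in zip(labels, feature) if f == 0]
    n = len(labels)
    return (entropy(labels)
            - len(left) / n * entropy(left)
            - len(right) / n * entropy(right))

label  = ["+", "+", "+", "-", "-", "-"]
feat_a = [1, 1, 0, 0, 1, 0]
feat_b = [0, 1, 1, 0, 0, 0]

print(round(information_gain(label, feat_a), 3))  # 0.082
print(round(information_gain(label, feat_b), 3))  # 0.459
```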

Building the tree recursively

Once you pick the best split for the root, you repeat the entire process on each child node. The left child becomes a new “mini-dataset,” and you search for the best split within it. Same for the right child. You keep going until one of these stopping conditions is met:

  • The node is pure (all samples have the same label).
  • The node has fewer samples than some minimum threshold.
  • The tree has reached a maximum depth.
  • No split improves impurity beyond a minimum threshold.

Without stopping conditions, the tree will keep splitting until every leaf contains exactly one sample. That’s a problem.
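Here is a minimal sketch of the whole recursive procedure in pure Python, with the stopping conditions wired in (a toy illustration for this post, not a production implementation):

```python
def gini(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def majority(labels):
    return max(set(labels), key=labels.count)

def build_tree(rows, labels, depth=0, max_depth=3, min_samples=2):
    # Stop if the node is pure, too small, or too deep
    if gini(labels) == 0.0 or len(labels) < min_samples or depth >= max_depth:
        return {"leaf": majority(labels)}

    n = len(labels)
    best = None  # (weighted_gini, feature_index, threshold)
    for f in range(len(rows[0])):
        values = sorted(set(r[f] for r in rows))
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [l for r, l in zip(rows, labels) if r[f] <= t]
            right = [l for r, l in zip(rows, labels) if r[f] > t]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[0]:
                best = (score, f, t)

    # Last stopping condition: no split improves impurity
    if best is None or best[0] >= gini(labels):
        return {"leaf": majority(labels)}

    _, f, t = best
    left_data = [(r, l) for r, l in zip(rows, labels) if r[f] <= t]
    right_data = [(r, l) for r, l in zip(rows, labels) if r[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build_tree([r for r, _ in left_data], [l for _, l in left_data],
                           depth + 1, max_depth, min_samples),
        "right": build_tree([r for r, _ in right_data], [l for _, l in right_data],
                            depth + 1, max_depth, min_samples),
    }

def predict(node, row):
    while "leaf" not in node:
        node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

# On the Age/Income example, this recovers the Income <= 52.5 root split
rows = [(25, 40), (30, 50), (35, 60), (40, 55), (45, 70), (50, 45), (28, 65), (33, 42)]
labels = ["No", "No", "Yes", "Yes", "Yes", "No", "Yes", "No"]
tree = build_tree(rows, labels)
print(predict(tree, (26, 41)))  # No
```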

Overfitting and the bias-variance tradeoff

A deep, unrestricted tree will fit the training data perfectly. Every training example gets its own leaf. Training accuracy: 100%. But the tree has memorized the noise in the training data, and it will perform badly on new data.

This is the classic bias-variance tradeoff. A deep tree has low bias (it can represent complex decision boundaries) but high variance (small changes in the training data lead to completely different trees). A shallow tree has higher bias but lower variance.

You can spot overfitting when there is a large gap between training accuracy and validation accuracy. If your tree gets 99% on training but 72% on validation, it is almost certainly overfitting.

Pruning

Pruning reduces tree complexity to fight overfitting. It is a form of regularization specific to tree-based models. There are two approaches.

Pruning: simplify the tree to fight overfitting

graph LR
  A["Full unpruned tree"] --> B["Many leaves, fits noise"]
  B --> C["Prune: remove weak splits"]
  C --> D["Simpler tree, fewer leaves"]
  D --> E["Better generalization"]

Pre-pruning (early stopping)

You set constraints before building the tree:

  • max_depth: limit how deep the tree can grow.
  • min_samples_split: require at least n samples to attempt a split.
  • min_samples_leaf: require at least n samples in every leaf.
  • min_impurity_decrease: only split if the impurity drops by at least some threshold.

These are hyperparameters you tune with cross-validation.

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=4
)
tree.fit(X_train, y_train)

Pre-pruning is simple and fast. The downside is that it might stop splitting too early. A split that looks useless on its own might enable a very useful split one level deeper. Pre-pruning cannot see that.
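The classic illustration is XOR-style data: no single split reduces impurity at the root, so any positive min_impurity_decrease halts growth immediately, even though a depth-2 tree separates the classes perfectly. A quick check on toy data:

```python
# XOR-style data: the label is 1 exactly when the two features differ
rows = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [0, 1, 1, 0]

def gini(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

# Neither single feature reduces impurity at the root...
for f in (0, 1):
    left = [l for r, l in zip(rows, labels) if r[f] == 0]
    right = [l for r, l in zip(rows, labels) if r[f] == 1]
    decrease = gini(labels) - 0.5 * gini(left) - 0.5 * gini(right)
    print(f"feature {f}: impurity decrease = {decrease}")  # 0.0 both times
# ...yet splitting on feature 0 and then feature 1 yields four pure leaves.
```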

Post-pruning (cost-complexity pruning)

Post-pruning builds the full tree first, then trims it back. The most common method is cost-complexity pruning, which defines a cost:

R_\alpha(T) = R(T) + \alpha \cdot |T|

where R(T) is the misclassification rate of tree T, |T| is the number of leaves, and α is a complexity parameter. A larger α penalizes bigger trees more heavily, resulting in more aggressive pruning.

You sweep over values of α and pick the one that minimizes validation error. Scikit-learn supports this directly:

# `tree` is the full, unpruned DecisionTreeClassifier fit earlier
path = tree.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas

best_alpha = None
best_score = 0.0
for a in alphas:
    # Refit with each candidate alpha and score on a held-out validation set
    t = DecisionTreeClassifier(ccp_alpha=a)
    t.fit(X_train, y_train)
    score = t.score(X_val, y_val)
    if score > best_score:
        best_score = score
        best_alpha = a

Handling continuous vs categorical features

Continuous features are straightforward. Sort the values and consider thresholds between consecutive unique values. For a feature with m unique values, you have m - 1 candidate thresholds.

Categorical features are trickier. If a feature has c categories, there are 2^{c-1} - 1 possible binary partitions. For a feature with 10 categories, that is 511 possible splits. This gets expensive.

Some implementations (like scikit-learn) require you to one-hot encode categorical features. Others (like LightGBM) handle categories natively using an optimal algorithm that sorts categories by their average target value, then tries splits on this sorted order. This reduces the problem to c - 1 candidate splits, the same as a continuous feature.

⚠ One-hot encoding deep categorical features (like zip codes with thousands of values) can degrade tree performance. Consider target encoding or using a library with native categorical support in those cases.
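A sketch of the target-sorting trick, written here for illustration with a hypothetical helper and made-up data (binary 0/1 targets assumed):

```python
def category_split_candidates(categories, labels):
    """Sort categories by mean target, then consider only splits of that
    ordering: c - 1 candidates instead of 2^(c-1) - 1."""
    cats = sorted(set(categories))
    mean_target = {
        c: sum(l for cat, l in zip(categories, labels) if cat == c) / categories.count(c)
        for c in cats
    }
    ordered = sorted(cats, key=lambda c: mean_target[c])
    # Each candidate is a (left-group, right-group) pair
    return [(set(ordered[:i]), set(ordered[i:])) for i in range(1, len(ordered))]

# Made-up data: which city's users bought the product
cities = ["NY", "LA", "SF", "NY", "SF", "LA", "NY"]
bought = [1, 0, 1, 1, 1, 0, 0]
for left, right in category_split_candidates(cities, bought):
    print(left, "|", right)  # only 2 candidates for 3 categories
```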

Decision trees for regression

Everything above applies to classification. For regression, the idea is the same, but the impurity measure and prediction change.

Instead of Gini or entropy, you use mean squared error (MSE) as the splitting criterion:

\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \bar{y})^2

where \bar{y} is the mean of the target values in the node.

Each leaf predicts the mean of the target values that fall into it. The best split is the one that minimizes the weighted MSE of the children.

from sklearn.tree import DecisionTreeRegressor

reg_tree = DecisionTreeRegressor(max_depth=4)
reg_tree.fit(X_train, y_train)
predictions = reg_tree.predict(X_test)

Regression trees have the same overfitting problems as classification trees. A tree with enough depth will memorize every training target value. The same pruning techniques apply.

One important limitation: regression trees predict constant values within each leaf region. They cannot extrapolate beyond the range of training data. If your training targets range from 0 to 100, the tree will never predict 105. Keep this in mind when your data has trends that extend beyond the training range.
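A hand-built regression stump makes the no-extrapolation point concrete (toy data where y = x on [0, 100], with a single split assumed at x = 50):

```python
# Training targets range from 0 to 100; each leaf predicts its mean
xs = list(range(0, 101, 10))
ys = [float(x) for x in xs]

threshold = 50
left_y  = [y for x, y in zip(xs, ys) if x <= threshold]
right_y = [y for x, y in zip(xs, ys) if x > threshold]
left_mean, right_mean = sum(left_y) / len(left_y), sum(right_y) / len(right_y)

def predict(x):
    # A leaf prediction is a constant: the mean of its training targets
    return left_mean if x <= threshold else right_mean

print(predict(200))  # 80.0: far outside the training range, yet the
                     # prediction is capped at the rightmost leaf mean
```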

Strengths and limitations

Strengths:

  • ✓ Easy to interpret and visualize. You can show a decision tree to a non-technical stakeholder and they will understand it.
  • ✓ No feature scaling needed. Trees only care about the rank order of values, not their magnitude.
  • ✓ Handle both numerical and categorical data.
  • ✓ Capture non-linear relationships and feature interactions naturally.

Limitations:

  • ✗ Unstable. Small changes in data can produce completely different trees.
  • ✗ Greedy algorithm. Each split is locally optimal, not globally optimal.
  • ✗ Axis-aligned splits only. A single tree cannot efficiently represent diagonal decision boundaries.
  • ✗ Prone to overfitting without pruning or depth limits.

These limitations are exactly why ensemble methods exist. By combining many trees, you get the best of both worlds: the expressiveness of trees with the stability of averaging.

What comes next

A single decision tree is powerful but fragile. The natural next step is to combine many trees to create something more robust. In Ensemble Methods: Bagging and Random Forests, we build on everything here to show how averaging over hundreds of trees reduces variance while keeping bias low. If you understood how a single tree works, ensembles will click quickly.
