Out-of-Bag Error in Random Forest [with example]


In this article, we explore the Out-of-Bag Error in Random Forest with a numeric example and a sample implementation that plots the Out-of-Bag Error.

Table of contents:

  1. Introduction to Random Forest
  2. Out-of-Bag Error in Random Forest
  3. Example of Out-of-Bag Error
  4. Code for Out-of-Bag Error

Introduction to Random Forest

Random Forest is a machine learning algorithm that uses bootstrap aggregation: it aggregates the results of several decision trees. A single decision tree tends to work well only when its depth is small; deeper trees overfit the training data. When many trees are combined in a Random Forest, this weakness of decision trees is mitigated.

Random Forest works by drawing random bootstrap samples of the training data (each typically the same size as the training set), fitting a decision tree to each sample, and letting each tree consider only a random subset of the features at every split. The prediction of the Random Forest is then a combination of the individual trees' predictions: typically a majority vote for classification or an average for regression.

Sampling random subsets of the training data with replacement (bootstrapping) and aggregating the resulting trees is what is referred to as bagging, short for bootstrap aggregating. The randomness in the data fed to each decision tree makes the trees differ from one another, which reduces the variance of the random forest's predictions.
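To make the sampling concrete, the following NumPy sketch draws one bootstrap sample of row indices and lists the rows that were never drawn, which are the out-of-bag rows for that tree. `bootstrap_indices` is a hypothetical helper written only for illustration. The sketch also estimates the fraction of rows left out of a bootstrap sample of size n, which approaches \((1 - 1/n)^n \approx e^{-1} \approx 0.368\) for large n, so roughly a third of the data is out-of-bag for each tree.

```python
import numpy as np


def bootstrap_indices(n_samples, rng):
    """Hypothetical helper: draw one bootstrap sample and split rows into in-bag / out-of-bag."""
    in_bag = rng.integers(0, n_samples, size=n_samples)  # sampled with replacement, may repeat
    oob_mask = np.ones(n_samples, dtype=bool)
    oob_mask[in_bag] = False  # any row drawn at least once is in-bag
    return in_bag, np.flatnonzero(oob_mask)


rng = np.random.default_rng(0)

# One bootstrap sample for a 7-row dataset like the toy example below.
in_bag, oob = bootstrap_indices(7, rng)
print("in-bag rows:", sorted(in_bag.tolist()))
print("out-of-bag rows:", oob.tolist())

# Average OOB fraction over many large bootstrap samples (expected to be close to e^-1).
n = 10_000
fractions = [bootstrap_indices(n, rng)[1].size / n for _ in range(50)]
print(f"mean OOB fraction: {np.mean(fractions):.3f}")
```

Using a boolean mask rather than a Python set keeps the out-of-bag lookup vectorized, which matters when the same step is repeated for hundreds of trees.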

Out-of-Bag Error in Random Forest

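The out-of-bag (OOB) error is the average error on the training samples, where each sample is predicted using only the trees that did not include it in their bootstrap sample. In this way, the Random Forest can be validated using only its training data, without a separate hold-out set. Consider the jth decision tree \(DT_j\), fitted on its bootstrap sample of the data. For every training observation \(z_i = (x_i, y_i)\) not in the bootstrap sample of \(DT_j\), where \(x_i\) is the set of features and \(y_i\) is the target, we use \(DT_j\) to predict the outcome \(o_i\) for \(x_i\). With class labels encoded as 0 or 1, the error of that prediction is \(|o_i - y_i|\), that is, 1 for a misclassification and 0 otherwise.
In the standard formulation, the predictions for \(z_i\) from all trees for which it is out-of-bag are combined (majority vote or averaged class probabilities for classification, an average for regression) into a single OOB prediction \(\bar{o}_i\), and the out-of-bag error is \(\frac{1}{n}\sum_{i=1}^{n} |\bar{o}_i - y_i|\), where \(n\) is the number of training samples that are out-of-bag for at least one tree. A simpler variant, used in the example below, averages the per-tree errors \(|o_i - y_i|\) over every (tree, out-of-bag sample) pair.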
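The standard computation can be sketched from scratch with scikit-learn's `DecisionTreeClassifier` as the base learner. This is a minimal sketch under stated assumptions (binary 0/1 labels, `max_features="sqrt"` per-split feature sampling, hard majority voting); `manual_oob_error` is a hypothetical helper, not scikit-learn's internal implementation, which averages class probabilities instead of counting votes.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def manual_oob_error(X, y, n_trees=50, seed=0):
    """Hypothetical helper: OOB error of a hand-rolled bagged forest (labels 0/1)."""
    rng = np.random.default_rng(seed)
    n = len(y)
    votes = np.zeros((n, 2), dtype=int)  # OOB vote counts for class 0 and class 1

    for _ in range(n_trees):
        in_bag = rng.integers(0, n, size=n)  # bootstrap sample, with replacement
        oob_mask = np.ones(n, dtype=bool)
        oob_mask[in_bag] = False  # rows never drawn are out-of-bag for this tree
        tree = DecisionTreeClassifier(
            max_features="sqrt", random_state=int(rng.integers(1_000_000))
        )
        tree.fit(X[in_bag], y[in_bag])
        oob_idx = np.flatnonzero(oob_mask)
        if oob_idx.size:
            # Each tree casts one vote for every row it did not train on.
            votes[oob_idx, tree.predict(X[oob_idx])] += 1

    has_oob = votes.sum(axis=1) > 0  # skip rows that were never out-of-bag
    oob_pred = votes[has_oob].argmax(axis=1)  # majority vote (ties go to class 0)
    return float(np.mean(oob_pred != y[has_oob]))


X, y = make_classification(n_samples=300, n_features=10, random_state=0)
err = manual_oob_error(X, y)
print(f"OOB error with 50 hand-rolled trees: {err:.3f}")
```

Note that each sample's error is computed once, after its votes are combined, rather than once per tree; this is what distinguishes the standard OOB error from the simpler per-tree average.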

Example of Out-of-Bag Error

The following is a simple example of how this works in practice. Consider this toy dataset, which records whether it rained given the temperature and humidity:

S/N  Temperature  Humidity  Rained?
1    33           High      No
2    18           Low       No
3    27           Low       Yes
4    20           High      Yes
5    21           Low       No
6    29           Low       Yes
7    19           High      Yes

Assume that a random forest ensemble of 5 decision trees \(DT_1, \dots, DT_5\) is to be trained on the dataset, with each tree trained on a random sample of it. Suppose that the sample drawn for \(DT_1\) contains the first five rows of the dataset. Then rows 6 and 7 are the out-of-bag samples on which \(DT_1\) is validated. Continuing with this assumption, let the following table give the prediction of each decision tree on each of its out-of-bag samples:

Tree  Sample S/N  Prediction  Actual  Error (abs)
DT1   6           No          Yes     1
DT1   7           No          Yes     1
DT2   2           No          No      0
DT3   1           No          No      0
DT3   2           Yes         No      1
DT3   4           Yes         Yes     0
DT4   2           Yes         No      1
DT4   7           No          Yes     1
DT5   3           Yes         Yes     0
DT5   5           No          No      0

Averaging the ten errors in the table gives an out-of-bag error of 5/10 = 0.5.
Because each error is computed using only the trees for which that sample is out-of-bag, which is a subset of the ensemble, the out-of-bag error may not be as accurate as a score measured on a separate validation set. However, when the dataset is as small as this one and setting aside a validation set is impractical, the out-of-bag error is a useful metric.
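The arithmetic above can be checked in plain Python. This sketch encodes the table, taking DT4's prediction for sample 7 as "No" (the value consistent with the error of 1 recorded for that row), and computes two quantities: the average over every (tree, out-of-bag sample) pair used in this example, and the standard per-sample OOB error, where each sample's out-of-bag predictions are first combined by majority vote. The variable names are illustrative only.

```python
from collections import Counter, defaultdict

# Actual "Rained?" label for each sample S/N in the toy dataset.
actual = {1: "No", 2: "No", 3: "Yes", 4: "Yes", 5: "No", 6: "Yes", 7: "Yes"}

# (tree, sample S/N, prediction) for every out-of-bag prediction in the table.
oob_predictions = [
    ("DT1", 6, "No"), ("DT1", 7, "No"),
    ("DT2", 2, "No"),
    ("DT3", 1, "No"), ("DT3", 2, "Yes"), ("DT3", 4, "Yes"),
    ("DT4", 2, "Yes"), ("DT4", 7, "No"),
    ("DT5", 3, "Yes"), ("DT5", 5, "No"),
]

# 1) Average over every (tree, OOB sample) pair, as in the worked example.
pair_errors = [int(pred != actual[sn]) for _, sn, pred in oob_predictions]
pair_avg = sum(pair_errors) / len(pair_errors)

# 2) Standard OOB error: majority vote per sample, then average over samples.
votes = defaultdict(list)
for _, sn, pred in oob_predictions:
    votes[sn].append(pred)
sample_errors = [
    int(Counter(preds).most_common(1)[0][0] != actual[sn])
    for sn, preds in votes.items()
]
oob_error = sum(sample_errors) / len(sample_errors)

print(f"per-pair average: {pair_avg:.2f}")       # 0.50
print(f"per-sample OOB error: {oob_error:.3f}")  # 0.429 (3 of 7 samples)
```

Sample 2 receives the votes No, Yes, Yes, so its majority-vote OOB prediction is Yes, which is wrong; samples 6 and 7 are misclassified by every tree for which they are out-of-bag. That gives 3 errors out of 7 samples, slightly below the 0.5 obtained by averaging per-tree errors.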

Code for Out-of-Bag Error

Lastly, we demonstrate how the out-of-bag error can guide the choice of n_estimators for a random forest model in the scikit-learn library. This example is adapted from the scikit-learn documentation. With warm_start=True, each classifier is grown from 15 to 150 trees in steps of five, and the out-of-bag error is recorded after each step. The resulting plot suggests that an n_estimators value of around 115 works best for the classifier with max_features='sqrt' in this example.

import matplotlib.pyplot as plt

from collections import OrderedDict
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

RANDOM_STATE = 123

# Generate a binary classification dataset.
X, y = make_classification(
    n_samples=500,
    n_features=25,
    n_clusters_per_class=1,
    n_informative=15,
    random_state=RANDOM_STATE,
)

# NOTE: Setting the `warm_start` construction parameter to `True` disables
# support for parallelized ensembles but is necessary for tracking the OOB
# error trajectory during training.
ensemble_clfs = [
    (
        "RandomForestClassifier, max_features='sqrt'",
        RandomForestClassifier(
            warm_start=True,
            oob_score=True,
            max_features="sqrt",
            random_state=RANDOM_STATE,
        ),
    ),
    (
        "RandomForestClassifier, max_features='log2'",
        RandomForestClassifier(
            warm_start=True,
            max_features="log2",
            oob_score=True,
            random_state=RANDOM_STATE,
        ),
    ),
    (
        "RandomForestClassifier, max_features=None",
        RandomForestClassifier(
            warm_start=True,
            max_features=None,
            oob_score=True,
            random_state=RANDOM_STATE,
        ),
    ),
]

# Map a classifier name to a list of (<n_estimators>, <error rate>) pairs.
error_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)

# Range of `n_estimators` values to explore.
min_estimators = 15
max_estimators = 150

for label, clf in ensemble_clfs:
    for i in range(min_estimators, max_estimators + 1, 5):
        clf.set_params(n_estimators=i)
        clf.fit(X, y)

        # Record the OOB error for each `n_estimators=i` setting.
        oob_error = 1 - clf.oob_score_
        error_rate[label].append((i, oob_error))

# Generate the "OOB error rate" vs. "n_estimators" plot.
for label, clf_err in error_rate.items():
    xs, ys = zip(*clf_err)
    plt.plot(xs, ys, label=label)

plt.xlim(min_estimators, max_estimators)
plt.xlabel("n_estimators")
plt.ylabel("OOB error rate")
plt.legend(loc="upper right")
plt.show()
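The earlier claim that the out-of-bag error can stand in for a validation score can be checked on the same synthetic data. This sketch fits one forest with `oob_score=True` and compares its OOB error with a 5-fold cross-validation error from `cross_val_score`; `n_estimators=115` is taken from the discussion of the plot above, and the two estimates should be close but need not match exactly. The fitted attribute `oob_decision_function_` holds the per-sample out-of-bag class probabilities.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

RANDOM_STATE = 123

# Same synthetic dataset as in the plotting example.
X, y = make_classification(
    n_samples=500,
    n_features=25,
    n_clusters_per_class=1,
    n_informative=15,
    random_state=RANDOM_STATE,
)

# A single fit with oob_score=True yields an OOB accuracy estimate at no extra cost.
forest = RandomForestClassifier(
    n_estimators=115, max_features="sqrt", oob_score=True, random_state=RANDOM_STATE
)
forest.fit(X, y)
oob_error = 1 - forest.oob_score_

# 5-fold cross-validation refits the forest five times, each time scoring a held-out fold.
cv_scores = cross_val_score(
    RandomForestClassifier(
        n_estimators=115, max_features="sqrt", random_state=RANDOM_STATE
    ),
    X,
    y,
    cv=5,
)
cv_error = 1 - cv_scores.mean()

print(f"OOB error:       {oob_error:.3f}")
print(f"5-fold CV error: {cv_error:.3f}")

# Per-sample out-of-bag class probabilities, shape (n_samples, n_classes).
print(forest.oob_decision_function_.shape)  # (500, 2)
```

The OOB estimate comes from a single training run, whereas cross-validation costs five additional fits; this difference is what makes the OOB error attractive when tuning parameters such as n_estimators.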
