Decision Tree Classification
In this article at OpenGenus, we have explained the concept of Decision Tree Classification in depth along with model implementation in Python.
Table of contents:
I. Introduction
II. Model Workings
III. Model code using Python libraries
IV. Model code from scratch without using ML based Python libraries
V. Summary
I. Introduction
Decision Tree Classification is one of the popular and widely used machine learning algorithms that is easy to understand and interpret. It is a supervised learning algorithm that can be used for both regression and classification tasks. The algorithm works by building a tree-like model of decisions and their possible consequences. Each internal node of the tree represents a test on a particular feature, each branch represents the outcome of the test, and each leaf node represents a class label or a continuous value. The algorithm uses the tree to make decisions by traversing from the root node to the leaf node based on the input features.
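To make this concrete, a trained decision tree is essentially a set of nested feature tests. The sketch below is a hand-written, hypothetical tree for the iris flower problem (the function name and the thresholds are made up for illustration, not learned from data), showing how an input is routed from the root test down to a leaf label:

# A hypothetical, hand-written decision tree for illustration only:
# each "if" is an internal node testing one feature, and each return is a leaf.
def classify_iris(petal_length_cm, petal_width_cm):
    if petal_length_cm < 2.5:           # root node: test on petal length
        return "setosa"                 # leaf node
    else:
        if petal_width_cm < 1.8:        # internal node: test on petal width
            return "versicolor"         # leaf node
        else:
            return "virginica"          # leaf node

print(classify_iris(1.4, 0.2))  # follows the left branch -> "setosa"
print(classify_iris(5.1, 2.3))  # follows the right branches -> "virginica"

A learning algorithm automates exactly this: it chooses which feature to test at each node and which threshold to test against.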
Decision Tree Classification is used in many fields such as finance, healthcare, and marketing, where it is used to analyze customer behavior, predict financial risk, and diagnose diseases. It is a popular algorithm in the industry due to its simplicity and ability to handle both numerical and categorical data.
In the next section, we will discuss how Decision Tree Classification works and how it can be used to solve classification problems.
II. Model Workings
Decision Tree Classification works by recursively splitting the dataset into subsets based on the values of the features, until the subsets are as pure as possible. The algorithm uses a decision tree to represent all the possible outcomes and their corresponding decisions.
The tree is constructed by choosing the feature that provides the best split of the dataset. The best split is the one that maximizes the information gain, which is a measure of the decrease in entropy or impurity of the dataset after the split. The algorithm calculates the entropy or impurity of the dataset using a measure such as Gini impurity or entropy. The impurity of a dataset is the measure of how mixed the dataset is with respect to the class labels.
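As a minimal sketch of these impurity measures, the snippet below computes entropy and Gini impurity for a small, made-up array of class labels (the label values are purely illustrative):

import numpy as np

def entropy(y):
    # Entropy: -sum(p * log2(p)) over the class proportions p
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return np.sum(-p * np.log2(p))

def gini(y):
    # Gini impurity: 1 - sum(p^2) over the class proportions p
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

labels = np.array([0, 0, 0, 1, 1, 1, 1, 1])  # made-up labels: 3 of class 0, 5 of class 1
print(f"Entropy: {entropy(labels):.3f}")        # about 0.954
print(f"Gini impurity: {gini(labels):.3f}")     # about 0.469

Both measures are 0 for a pure node and largest when the classes are evenly mixed, so a split that lowers them is a split that makes the subsets more homogeneous.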
The tree-building process continues until all the data points belong to the same class or the maximum depth of the tree is reached. The maximum depth of the tree is a hyperparameter that can be adjusted to prevent overfitting or underfitting.
Once the tree is built, it can be used to make predictions for new data points by traversing the tree from the root node to the leaf node based on the values of the input features.
In the next section, we will discuss how to implement Decision Tree Classification using Python libraries.
III. Model code using Python libraries
Python offers several libraries for building tree-based models, with scikit-learn being the most widely used for classic Decision Tree Classification. In this section, we will focus on scikit-learn.
Scikit-learn provides a DecisionTreeClassifier class that can be used to build a Decision Tree Classification model. The class provides several hyperparameters that can be adjusted to control the complexity and performance of the model. Some of the important hyperparameters are:
criterion: The measure used to evaluate the quality of a split. It can be "gini" or "entropy".
max_depth: The maximum depth of the tree. It controls the complexity of the model and helps prevent overfitting.
min_samples_split: The minimum number of samples required to split an internal node. It controls the minimum size of a subset that can be split further.
min_samples_leaf: The minimum number of samples required to be at a leaf node. It controls the minimum size of a leaf.
Here's an example of how to build a Decision Tree Classification model using scikit-learn:
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load the dataset
iris = load_iris()
X = iris.data
y = iris.target
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create a Decision Tree Classifier
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, min_samples_split=4, min_samples_leaf=2)
# Train the classifier on the training set
clf.fit(X_train, y_train)
# Make predictions on the testing set
y_pred = clf.predict(X_test)
# Evaluate the accuracy of the model
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
In this example, we first load the iris dataset using the load_iris() function from scikit-learn. We then split the dataset into training and testing sets using the train_test_split() function. Next, we create a Decision Tree Classifier with hyperparameters criterion="gini", max_depth=3, min_samples_split=4, and min_samples_leaf=2. We train the classifier on the training set using the fit() method and make predictions on the testing set using the predict() method. Finally, we evaluate the accuracy of the model using the accuracy_score() function from scikit-learn.
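To see how a hyperparameter influences the model, one simple experiment (a sketch that reuses the X_train/X_test split from above) is to vary max_depth and compare training and testing accuracy; very shallow trees tend to underfit, while very deep trees can overfit the training data:

# Sketch: vary max_depth and compare train vs. test accuracy (reuses the split above)
for depth in [1, 2, 3, 5, None]:
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, model.predict(X_train))
    test_acc = accuracy_score(y_test, model.predict(X_test))
    print(f"max_depth={depth}: train accuracy={train_acc:.3f}, test accuracy={test_acc:.3f}")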
In the next section, we will discuss how to implement Decision Tree Classification without using machine learning libraries.
IV. Model code from scratch without using ML based Python libraries
Implementing a Decision Tree Classification model from scratch without using any machine learning libraries can be challenging but also rewarding, as it provides a deeper understanding of how the algorithm works. Here, we will implement an ID3-style algorithm, one of the classic Decision Tree approaches, adapted to numerical features by splitting on thresholds.
The ID3 algorithm works by selecting the feature that provides the maximum information gain, which is the difference between the entropy of the parent node and the weighted sum of the entropies of the child nodes. The entropy of a node is a measure of the impurity of the node with respect to the class labels. A node is pure if all the data points in the node belong to the same class.
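As a small worked example (with made-up labels), consider a parent node whose samples are divided between two children by a candidate split; the gain is the parent entropy minus the size-weighted average of the child entropies:

import numpy as np

def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return np.sum(-p * np.log2(p))

# Made-up parent node with 4 samples of class 0 and 4 of class 1
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
# A candidate split that separates the classes fairly well
left, right = np.array([0, 0, 0, 1]), np.array([0, 1, 1, 1])

weighted_child_entropy = (len(left) / len(parent)) * entropy(left) \
                         + (len(right) / len(parent)) * entropy(right)
gain = entropy(parent) - weighted_child_entropy
print(f"Information gain: {gain:.3f}")  # 1.0 - 0.811 = about 0.189

The algorithm evaluates this gain for every candidate feature and threshold and greedily picks the split with the largest value.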
Here's an example of how to implement this algorithm from scratch:
import numpy as np
def entropy(y):
    """Calculates the entropy of a list of class labels."""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / counts.sum()
    return np.sum(probabilities * -np.log2(probabilities))
class DecisionTreeClassifier:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth

    def fit(self, X, y):
        # Assumes class labels are integers 0 .. n_classes-1 (as in the iris dataset)
        self.n_classes = len(np.unique(y))
        self.n_features = X.shape[1]
        self.tree = self._grow_tree(X, y)

    def predict(self, X):
        return np.array([self._traverse_tree(x, self.tree) for x in X])

    def _grow_tree(self, X, y, depth=0):
        # Count of each class in the current node
        num_samples_per_class = [np.sum(y == i) for i in range(self.n_classes)]
        # Predict the class with the most samples
        predicted_class = np.argmax(num_samples_per_class)
        # Create a leaf node if the maximum depth is reached or the node is pure
        if depth == self.max_depth or len(np.unique(y)) == 1:
            return predicted_class
        # Select the best feature and threshold to split on
        best_feature, best_threshold = self._best_split(X, y)
        # Create a leaf node if no valid split was found
        if best_feature is None:
            return predicted_class
        # Grow the left and right subtrees
        left_idxs = X[:, best_feature] < best_threshold
        right_idxs = X[:, best_feature] >= best_threshold
        left = self._grow_tree(X[left_idxs, :], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs, :], y[right_idxs], depth + 1)
        # Return a new decision node
        return DecisionNode(best_feature, best_threshold, left, right)
    def _best_split(self, X, y):
        best_feature, best_threshold = None, None
        max_gain = -1
        # Calculate the entropy of the parent node
        parent_entropy = entropy(y)
        for feature in range(self.n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                # Split the data into two subsets
                left_idxs = X[:, feature] < threshold
                right_idxs = ~left_idxs
                # Skip if the split does not create two non-empty subsets
                if left_idxs.sum() == 0 or right_idxs.sum() == 0:
                    continue
                # Calculate the information gain of the split
                left_entropy = entropy(y[left_idxs])
                right_entropy = entropy(y[right_idxs])
                n_left, n_right = left_idxs.sum(), right_idxs.sum()
                gain = parent_entropy - (n_left / len(y) * left_entropy + n_right / len(y) * right_entropy)
                # Keep track of the best split found so far
                if gain > max_gain:
                    best_feature, best_threshold = feature, threshold
                    max_gain = gain
        return best_feature, best_threshold
    def _traverse_tree(self, x, node):
        # Internal nodes are DecisionNode instances; leaves are plain class labels
        if isinstance(node, DecisionNode):
            if x[node.feature] < node.threshold:
                return self._traverse_tree(x, node.left)
            else:
                return self._traverse_tree(x, node.right)
        else:
            return node
class DecisionNode:
    """Internal tree node storing the split feature, threshold, and child subtrees."""
    def __init__(self, feature, threshold, left, right):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
The above is the complete code for a basic version of a Decision Tree Classification model built from scratch.
The code below shows how to use this from-scratch model in the same way as a model from an ML library:
# Example usage
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load iris dataset
iris = load_iris()
X, y = iris.data, iris.target
# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create and train the decision tree
clf = DecisionTreeClassifier(max_depth=3)
clf.fit(X_train, y_train)
# Make predictions on the test set
y_pred = clf.predict(X_test)
# Calculate accuracy
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc:.3f}")
This code block shows an example usage of the from-scratch DecisionTreeClassifier class for a classification problem using the iris dataset. Note that scikit-learn is used here only to load the data, split it, and compute the accuracy; the classifier itself is the implementation above.
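As a quick sanity check, one can compare this from-scratch model against scikit-learn's implementation on the same split (a sketch; the alias SkDecisionTree is used only to avoid the name clash with our class). The two trees will not necessarily agree exactly, but their accuracies should be in the same ballpark:

# Sketch: compare the from-scratch model with scikit-learn's implementation
from sklearn.tree import DecisionTreeClassifier as SkDecisionTree

sk_clf = SkDecisionTree(criterion="entropy", max_depth=3, random_state=42)
sk_clf.fit(X_train, y_train)
sk_acc = accuracy_score(y_test, sk_clf.predict(X_test))

print(f"From-scratch accuracy: {acc:.3f}")
print(f"scikit-learn accuracy: {sk_acc:.3f}")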
V. Summary
In this article at OpenGenus, we explored the decision tree classification model in machine learning. Decision trees are easy to interpret and can handle both numerical and categorical data. We saw how decision trees work to classify data by recursively splitting the dataset into smaller subsets using the features that are most informative for prediction.
The advantages of using decision trees for classification problems include:
Interpretability: The decisions made by the model are easy to interpret and can provide insights into the relationships between variables.
Ease of use: Decision trees are simple to understand and require little to no data preprocessing.
Handles mixed data: Decision trees can handle both numerical and categorical data.
Decision trees work best for small to medium-sized datasets and problems with non-linear relationships between variables. They are also useful for exploring the structure of a dataset and identifying important features for prediction.
Overall, decision trees are a powerful tool in machine learning and can be applied to a wide range of classification problems. By understanding how decision trees work and their strengths and weaknesses, you can make informed decisions about when and how to use them.