Welcome to Chapter 6!
In Chapter 5: Clustering, we learned how to group data without any labels. Before that, in Chapter 3: Linear Models, we learned how to draw straight lines to make predictions.
But what if the answer isn't a straight line? What if the logic is more like a game of "20 Questions"?
Imagine you are a doctor trying to diagnose a patient. You don't start with a complex mathematical calculation. Instead, you ask a series of Yes/No questions: "Does the patient have a fever?" If yes: "Is there a rash?" If no: "Is there a cough?" Each answer narrows down the possibilities.
This logic forms a Tree. You start at the top (the Root) and travel down different branches based on the answers until you reach a diagnosis (a Leaf).
The Problem: Writing out these rules by hand ("If X > 5 and Y < 2...") is tedious and prone to error. The Solution: Decision Trees. We give the computer the data, and it figures out the best questions to ask to separate the data perfectly.
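To feel the pain of hand-written rules, here is what such a flowchart looks like as code. This is only an illustrative sketch: the thresholds below are made up for the example, not learned from data.

```python
# A hand-written "decision tree" for fruit: tedious to write and maintain.
# The cutoffs (150, 5) are invented for illustration, not learned.
def classify_fruit(weight, texture):
    if weight <= 150:        # lighter fruit...
        if texture >= 5:     # ...with smooth skin
            return "Apple"
        return "Orange"
    return "Orange"

print(classify_fruit(140, 10))  # Apple
print(classify_fruit(160, 1))   # Orange
```

Every new feature or edge case means another nested `if`; a Decision Tree generates these rules for us.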
We want to distinguish between Apples and Oranges based on two features: the Weight (in grams) and the Texture (a score where low values mean bumpy skin and high values mean smooth skin).
We want the model to build a flowchart that can tell us which fruit we are holding.
A Decision Tree is built of three main parts: the Root (the very first question, at the top), the Branches (the Yes/No paths you follow), and the Leaves (the end points that hold the final answers).
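These parts map naturally onto a small data structure. The `Node` class below is an illustrative sketch, not scikit-learn's internal representation:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    # Internal node: which question to ask
    feature: Optional[str] = None      # e.g. "Weight"
    threshold: Optional[float] = None  # e.g. 150.0
    left: Optional["Node"] = None      # answer was Yes (value <= threshold)
    right: Optional["Node"] = None     # answer was No (value > threshold)
    # Leaf node: the final answer
    prediction: Optional[str] = None

# The Root asks one question; each Branch ends in a Leaf
root = Node(feature="Weight", threshold=150.0,
            left=Node(prediction="Apple"),
            right=Node(prediction="Orange"))
```

A deeper tree is just more `Node` objects chained through `left` and `right`.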
How does the tree decide which question to ask? It looks for "Purity."
The tree tries to find a question that splits a mixed bucket into two purer buckets. This is measured using a metric called Gini Impurity.
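For a bucket of labels, Gini Impurity is 1 minus the sum of squared class proportions: 0.0 means perfectly pure, and 0.5 is the worst case for two classes. A minimal sketch:

```python
# Gini impurity of a bucket: 1 - sum(p_i^2) over the class proportions p_i
def gini_impurity(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

print(gini_impurity([0, 0, 1, 1]))  # 0.5 -> perfectly mixed (worst case)
print(gini_impurity([1, 1, 1, 1]))  # 0.0 -> perfectly pure
```

A good question is one whose two child buckets have a lower (weighted) impurity than the parent bucket.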
Let's build a DecisionTreeClassifier.
We have 4 fruits.
from sklearn.tree import DecisionTreeClassifier
# Features: [Weight, Texture]
X = [[160, 1], [170, 2], [140, 10], [130, 9]]
# Labels: 0 = Orange, 1 = Apple
y = [0, 0, 1, 1]
We instantiate the class and fit it, just like in previous chapters.
# Create the Tree model
clf = DecisionTreeClassifier()
# Teach the tree to distinguish fruits
clf.fit(X, y)
Now we have a new fruit: Weight 150g, Texture 8 (Smooth).
# Predict for [150g, Smooth]
prediction = clf.predict([[150, 8]])
# 0 is Orange, 1 is Apple
print("Fruit is:", "Apple" if prediction[0] == 1 else "Orange")
# Output: Fruit is: Apple
The best part about Trees is that they aren't "black boxes." We can see exactly why it made that decision.
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# Draw the tree
plot_tree(clf, feature_names=['Weight', 'Texture'], filled=True)
plt.show()
Result: You will see a diagram. The top box likely says Weight <= 150.0.
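If you prefer plain text over a matplotlib window, `export_text` prints the same flowchart as indented ASCII. A small sketch refitting the same toy data (either feature separates these fruits perfectly, so which one wins the tie can vary):

```python
from sklearn.tree import DecisionTreeClassifier, export_text

X = [[160, 1], [170, 2], [140, 10], [130, 9]]
y = [0, 0, 1, 1]
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Print the learned rules as indented text
print(export_text(clf, feature_names=['Weight', 'Texture']))
```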
The tree learned that Weight was the most important factor to distinguish these fruits.
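You can also ask the fitted model directly which features mattered. A small sketch refitting the same toy data; `feature_importances_` always sums to 1.0, and because one question separates these fruits perfectly, the chosen feature gets all the credit (the tie-break between Weight and Texture depends on `random_state`):

```python
from sklearn.tree import DecisionTreeClassifier

X = [[160, 1], [170, 2], [140, 10], [130, 9]]
y = [0, 0, 1, 1]
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Importance = how much each feature reduced impurity across the tree
for name, score in zip(['Weight', 'Texture'], clf.feature_importances_):
    print(f"{name}: {score:.2f}")
```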
You might wonder: How did it know to pick 150.0 as the cutoff number? Why not 145? Why not ask about Texture first?
The algorithm used is called CART (Classification and Regression Trees). It uses a "Greedy" strategy: at every node it tries all candidate questions and keeps whichever produces the purest buckets right now, without ever looking ahead. The candidate cutoffs are the midpoints between neighboring sorted values, which is why you see 150.0 (halfway between the 140g apple and the 160g orange) rather than 145.
Checking every split point for every feature is computationally expensive. If you have 1 million rows, a pure-Python loop would take forever.
To solve this, scikit-learn implements the core logic in Cython, in files such as sklearn/tree/_tree.pyx and sklearn/tree/_splitter.pyx. This allows the code to run at C speed.
The internal structure relies on two main helper classes: the Splitter, which searches for the best question at each node, and the Criterion, which scores a candidate split's impurity (Gini by default).
Here is a conceptual Python version of what the fast Cython Splitter does.
# Conceptual logic behind sklearn/tree/_tree.pyx
# (the real version runs as compiled Cython, not Python)
import numpy as np

def find_best_split(X, y):
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    n_samples, n_features = X.shape
    best_gini, best_split = 1.0, None

    def gini(labels):
        # Impurity of one bucket: 1 minus the sum of squared class shares
        if len(labels) == 0:
            return 0.0
        p = np.bincount(labels) / len(labels)
        return 1.0 - np.sum(p ** 2)

    # Loop 1: Go through every feature (Weight, Texture)
    for feature_index in range(n_features):
        # Loop 2: Try a cut at the midpoint between neighboring values
        values = np.unique(X[:, feature_index])
        for threshold in (values[:-1] + values[1:]) / 2:
            # Try splitting here
            left = X[:, feature_index] <= threshold
            # Calculate impurity (messiness), weighted by bucket size
            score = (left.sum() * gini(y[left])
                     + (~left).sum() * gini(y[~left])) / n_samples
            # Keep the winner
            if score < best_gini:
                best_gini, best_split = score, (feature_index, threshold)
    return best_split

# On the fruit data, the best question is about Weight:
print(find_best_split([[160, 1], [170, 2], [140, 10], [130, 9]],
                      [0, 0, 1, 1]))  # feature 0, threshold 150.0
Explanation:

- Loop 1 walks over every feature; Loop 2 tries every candidate cutoff for that feature.
- Each candidate split is scored by the weighted Gini impurity of the two resulting buckets; lower means purer.
- Once the best_split is found, the builder creates a Node and repeats the whole process recursively on the data that went left and the data that went right.

In this chapter, we learned:

- How a Decision Tree plays "20 Questions" with data, travelling from the Root down to a Leaf.
- How to train a DecisionTreeClassifier and inspect its logic with plot_tree.
- How the CART algorithm greedily searches for the split with the lowest Gini Impurity.
- How scikit-learn uses compiled Cython code (_tree.pyx) to do it fast.

Trees are powerful, but they have a weakness: they can easily over-complicate things (overfitting). What if, instead of one tree, we asked a committee of 100 trees to vote on the answer?
That is the topic of the next chapter.
Generated by Code IQ