Welcome to Chapter 11!
In Chapter 10: Pipelines, we learned how to chain multiple complex steps into a single "conveyor belt" of machine learning. We have built custom models, processed text, and handled mixed data types.
But there is one scary question we haven't asked yet: How do we know our code isn't broken?
Imagine you are a car manufacturer. You build a new car. Do you just sell it immediately? No. You crash it into a wall to ensure the airbag deploys. You drive it in the rain to ensure the wipers work.
In software, this is called Testing.
The Problem:
In Python, 0.1 + 0.2 evaluates to 0.30000000000000004. If you test result == 0.3, your test fails!
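You can see the problem for yourself in plain Python, no scikit-learn required:

```python
# Floating-point addition is not exact: 0.1 and 0.2 have no finite
# binary representation, so a tiny error creeps into the sum.
result = 0.1 + 0.2

print(result)          # 0.30000000000000004
print(result == 0.3)   # False -- a naive equality check fails!
```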
The Solution: Scikit-learn provides Testing Utilities in sklearn.utils._testing. These are special helper functions designed specifically to test machine learning code.
We want to verify the math of a simple function.
We will use functions that start with assert_. In programming, an "assertion" is a statement that says: "I bet this is true. If I am wrong, stop everything and yell at me."
- assert_array_equal: Checks that two arrays of integers or strings are exactly the same.
- assert_allclose: Checks that two arrays of floats are approximately the same (ignoring tiny floating-point errors).
- assert_raise_message: Checks that a specific error, carrying a specific message, is raised when things go wrong.

Let's write a test script using these utilities.
Computers often make tiny rounding errors. Let's see how scikit-learn handles this.
import numpy as np
from sklearn.utils._testing import assert_allclose
# The "True" value
expected = [1.0, 2.0, 3.0]
# The "Calculated" value (slightly off due to math)
calculated = [1.0, 2.0000000001, 2.9999999999]
# This passes because they are "close enough"
assert_allclose(expected, calculated)
print("Test Passed: The numbers are effectively equal!")
Explanation: If we used assert expected == calculated, it would fail. assert_allclose knows that 2.9999999999 is basically 3.0.
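How close is "close enough"? You can tune that with the rtol (relative tolerance) and atol (absolute tolerance) parameters. As a sketch, a deliberately strict tolerance makes the same comparison fail:

```python
from sklearn.utils._testing import assert_allclose

expected = [1.0, 2.0, 3.0]
calculated = [1.0, 2.0000000001, 2.9999999999]

# Default tolerance: the tiny differences are forgiven.
assert_allclose(expected, calculated)

# An extremely strict relative tolerance makes the same data fail.
try:
    assert_allclose(expected, calculated, rtol=1e-15)
    strict_passed = True
except AssertionError:
    strict_passed = False

print("Strict comparison passed?", strict_passed)  # False
```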
For integers (like counting apples) or strings (like class names), we expect exact matches.
from sklearn.utils._testing import assert_array_equal
# True labels vs Predicted labels
y_true = ["cat", "dog", "cat"]
y_pred = ["cat", "dog", "cat"]
# This checks exact equality
assert_array_equal(y_true, y_pred)
print("Test Passed: Arrays are identical.")
This sounds weird, but sometimes we want our code to fail. If a user tries to predict the price of a house using the word "Banana" as the size, our model should shout "Invalid Input!".
We use assert_raise_message to verify the shouting happens.
import numpy as np
from sklearn.utils._testing import assert_raise_message

def calculate_square_root(x):
    if x < 0:
        raise ValueError("Input must be positive")
    return np.sqrt(x)

# We expect a ValueError containing the text "positive"
# when we pass -5. assert_raise_message takes the function
# and its arguments and makes the call for us.
assert_raise_message(ValueError, "positive", calculate_square_root, -5)
print("Test Passed: The function correctly raised an error.")
Explanation: The call we hand to assert_raise_message must crash with the right error and message. If calculate_square_root(-5) accidentally worked and returned a number, the test would fail!
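We can demonstrate that flip side directly: handing assert_raise_message a call that succeeds makes the utility itself raise an AssertionError.

```python
import numpy as np
from sklearn.utils._testing import assert_raise_message

def calculate_square_root(x):
    if x < 0:
        raise ValueError("Input must be positive")
    return np.sqrt(x)

# Passing a valid input means no ValueError occurs, so the
# assertion utility complains instead.
try:
    assert_raise_message(ValueError, "positive", calculate_square_root, 25)
    test_failed = False
except AssertionError:
    test_failed = True

print("Test correctly failed when no error was raised:", test_failed)  # True
```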
Sometimes, we know our code generates a warning (like "Function X is deprecated"), but we want our test output to be clean. We can tell the test suite to shut up about it.
import warnings

from sklearn.utils._testing import ignore_warnings

@ignore_warnings(category=UserWarning)
def noisy_function():
    warnings.warn("This is a loud warning!")
    return True

# Run it. No warning will be printed to the console.
noisy_function()
print("Test Passed: Silence achieved.")
These utilities are wrappers around the popular numpy.testing module and the standard Python unittest framework. They act as "judges" for your code.
When you call assert_allclose, a detailed comparison happens behind the scenes: the utility delegates to numpy.testing, which computes the difference (the delta) between your two arrays and compares it against the allowed tolerance. If the math says "No, the delta is too big," the utility raises an AssertionError, stopping your program.
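The closeness rule itself is NumPy's tolerance formula: two values match when |actual - desired| <= atol + rtol * |desired|. A minimal sketch of that check (the function name is ours, not scikit-learn's):

```python
import numpy as np

def simple_allclose_check(actual, desired, rtol=1e-7, atol=0.0):
    """Sketch of numpy.testing's tolerance rule: values match when
    |actual - desired| <= atol + rtol * |desired|."""
    actual = np.asarray(actual, dtype=float)
    desired = np.asarray(desired, dtype=float)
    delta = np.abs(actual - desired)
    allowed = atol + rtol * np.abs(desired)
    if not np.all(delta <= allowed):
        raise AssertionError(f"Delta too big: {delta} > {allowed}")

simple_allclose_check([2.9999999999], [3.0])  # passes silently
```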
The code for these utilities resides in sklearn/utils/_testing.py.
Here is a simplified Python sketch of how this kind of "expect an error" check can be implemented. It uses a Python feature called a Context Manager (__enter__ and __exit__). (Note: the real assert_raise_message wraps the function call itself, but scikit-learn's _testing module also offers a context-manager helper, raises, built on this pattern.)
# Simplified logic of an error-checking context manager
class SimpleAssertRaise:
    def __init__(self, expected_error, expected_msg):
        self.expected_error = expected_error
        self.expected_msg = expected_msg

    def __enter__(self):
        # Start listening
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # 1. Check if an error occurred at all
        if exc_type is None:
            raise AssertionError("No error was raised!")
        # 2. Check if it was the RIGHT error (e.g., ValueError)
        if not issubclass(exc_type, self.expected_error):
            raise AssertionError(f"Wrong error type: {exc_type}")
        # 3. Check if the message matches
        if self.expected_msg not in str(exc_value):
            raise AssertionError("Message didn't match!")
        return True  # Suppress the actual error so testing continues
Explanation:
- __enter__: The test starts.
- __exit__: This runs after the code block finishes (or crashes).
In Chapter 1: Base API, we built a MajorityClassifier. To test it properly, we wouldn't just write one test. We would check:
- Does fit work with a valid array? (assert_array_equal on shapes).
- Does predict return the expected class?
- Does it complain when X is empty? (assert_raise_message).

Scikit-learn developers use these utilities every day to ensure that when they add a new feature (like the Ensembles from Chapter 7), they don't accidentally break the Linear Models from Chapter 3.
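As a sketch, here is what such a checklist might look like for a toy majority-vote classifier. The MajorityClassifier below is a minimal stand-in written for this example, not the actual Chapter 1 code:

```python
import numpy as np
from sklearn.utils._testing import assert_array_equal, assert_raise_message

class MajorityClassifier:
    """Toy stand-in: always predicts the most common training label."""

    def fit(self, X, y):
        if len(X) == 0:
            raise ValueError("X must not be empty")
        labels, counts = np.unique(y, return_counts=True)
        self.majority_ = labels[np.argmax(counts)]
        return self

    def predict(self, X):
        return np.full(len(X), self.majority_)

clf = MajorityClassifier()
clf.fit([[1], [2], [3]], ["cat", "cat", "dog"])

# 1. Does predict return the expected class?
assert_array_equal(clf.predict([[10], [20]]), ["cat", "cat"])

# 2. Does fit complain when X is empty?
assert_raise_message(ValueError, "empty", MajorityClassifier().fit, [], [])

print("All checklist tests passed!")
```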
In this chapter, we learned:
- assert_allclose helps us compare decimal numbers without worrying about tiny math errors.
- assert_raise_message ensures our models complain correctly when fed garbage.
- These utilities are thin wrappers around numpy and unittest tools.

We have learned how to write our own tests. But wouldn't it be nice if scikit-learn had a built-in checklist to verify that our custom model follows all the rules of the library?
It does! It's called Common Tests.
Generated by Code IQ