This notebook contains all the sample code and solutions to the exercises in chapter 7.
Setup
First, let's import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated, so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥ 0.20.
# Common imports
import numpy as np
import os
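The rest of the setup cell is not shown in this excerpt. A minimal sketch of what the description above implies (the figure directory and the save_fig helper name match the calls below; the exact paths are assumptions):
import sys
assert sys.version_info >= (3, 5)          # Python 3.5+ required

import sklearn
assert sklearn.__version__ >= "0.20"       # Scikit-Learn 0.20+ required

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# Where to save the figures (the path is an assumption)
IMAGES_PATH = os.path.join(".", "images", "ensembles")
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)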
Voting classifiers
heads_proba = 0.51
coin_tosses = (np.random.rand(10000, 10) < heads_proba).astype(np.int32)
cumulative_heads_ratio = np.cumsum(coin_tosses, axis=0) / np.arange(1, 10001).reshape(-1, 1)
plt.figure(figsize=(8,3.5))
plt.plot(cumulative_heads_ratio)
plt.plot([0, 10000], [0.51, 0.51], "k--", linewidth=2, label="51%")
plt.plot([0, 10000], [0.5, 0.5], "k-", label="50%")
plt.xlabel("Number of coin tosses")
plt.ylabel("Heads ratio")
plt.legend(loc="lower right")
plt.axis([0, 10000, 0.42, 0.58])
save_fig("law_of_large_numbers_plot")
plt.show()
Note: to be future-proof, we set solver="lbfgs", n_estimators=100, and gamma="scale", since these will be the default values in upcoming Scikit-Learn versions.
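The dataset and the three individual classifiers used below are not defined in this excerpt. A minimal sketch consistent with the note above (the moons dataset and the train/test split are assumptions based on the rest of the chapter):
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

X, y = make_moons(n_samples=500, noise=0.30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

log_clf = LogisticRegression(solver="lbfgs", random_state=42)
rnd_clf = RandomForestClassifier(n_estimators=100, random_state=42)
svm_clf = SVC(gamma="scale", random_state=42)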
voting_clf = VotingClassifier(
    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],
    voting='hard')
voting_clf.fit(X_train, y_train)
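The per-classifier accuracies shown below presumably come from an evaluation loop along these lines (a sketch):
from sklearn.metrics import accuracy_score

for clf in (log_clf, rnd_clf, svm_clf, voting_clf):
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print(clf.__class__.__name__, accuracy_score(y_test, y_pred))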
LogisticRegression 0.864
RandomForestClassifier 0.896
SVC 0.896
VotingClassifier 0.912
Soft voting:
# For soft voting, every estimator must be able to estimate class probabilities,
# so the SVC must be created with probability=True (slower to train, since it uses cross-validation).
voting_clf = VotingClassifier(
    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],
    voting='soft')
voting_clf.fit(X_train, y_train)
LogisticRegression 0.864
RandomForestClassifier 0.896
SVC 0.896
VotingClassifier 0.92
Bagging ensembles
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

bag_clf = BaggingClassifier(
    DecisionTreeClassifier(random_state=42), n_estimators=500,
    max_samples=100, bootstrap=True, random_state=42)
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)
accuracy_score(y_test, y_pred)
0.904
tree_clf = DecisionTreeClassifier(random_state=42)
tree_clf.fit(X_train, y_train)
y_pred_tree = tree_clf.predict(X_test)
print(accuracy_score(y_test, y_pred_tree))
0.856
Random Forests
bag_clf = BaggingClassifier(
    DecisionTreeClassifier(splitter="random", max_leaf_nodes=16, random_state=42),
    n_estimators=500, max_samples=1.0, bootstrap=True, random_state=42)
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)
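The y_pred_rf line and the feature importances below refer to a rnd_clf that is not defined in this excerpt; here is a sketch of the (roughly) equivalent Random Forest, with assumed hyperparameters. The 0.976 shown afterwards is then presumably the fraction of test instances on which the two models agree, e.g. np.sum(y_pred == y_pred_rf) / len(y_pred).
from sklearn.ensemble import RandomForestClassifier

rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, random_state=42)
rnd_clf.fit(X_train, y_train)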
y_pred_rf = rnd_clf.predict(X_test)
0.976
rnd_clf.feature_importances_
plt.figure(figsize=(6, 4))
for i in range(15):
    tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)
    indices_with_replacement = np.random.randint(0, len(X_train), len(X_train))
    tree_clf.fit(X[indices_with_replacement], y[indices_with_replacement])
    plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.02, contour=False)
plt.show()
Out-of-Bag evaluation
bag_clf = BaggingClassifier(
    DecisionTreeClassifier(random_state=42), n_estimators=500,
    bootstrap=True, oob_score=True, random_state=40)
bag_clf.fit(X_train, y_train)
bag_clf.oob_score_
0.9013333333333333
bag_clf.oob_decision_function_
array([[0.31746032, 0.68253968],
       [0.34117647, 0.65882353],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.08379888, 0.91620112],
       [0.31693989, 0.68306011],
       [0.02923977, 0.97076023],
       ...,
       [0.06043956, 0.93956044],
       [0.97619048, 0.02380952],
       [1.        , 0.        ],
       [0.03108808, 0.96891192],
       [0.57291667, 0.42708333]])
The OOB evaluation is close to the accuracy obtained on the test set:
y_pred = bag_clf.predict(X_test)
accuracy_score(y_test, y_pred)
0.912
Feature importance
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap=mpl.cm.hot, interpolation="nearest")
    plt.axis("off")
plot_digit(rnd_clf.feature_importances_)
save_fig("mnist_feature_importance_plot")
plt.show()
AdaBoost
from sklearn.ensemble import AdaBoostClassifier

ada_clf = AdaBoostClassifier(
    DecisionTreeClassifier(max_depth=1), n_estimators=200,
    algorithm="SAMME.R", learning_rate=0.5, random_state=42)
ada_clf.fit(X_train, y_train)
AdaBoostClassifier(algorithm='SAMME.R',
base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=None,
splitter='best'),
learning_rate=0.5, n_estimators=200, random_state=42)
plot_decision_boundary(ada_clf, X, y)
m = len(X_train)
save_fig("boosting_plot")
plt.show()
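The m = len(X_train) line and the boosting_plot figure above appear to come from a manual demonstration of sequential sample reweighting, the mechanism underlying AdaBoost: each new predictor is trained with higher weights on the instances its predecessor misclassified. A sketch of that idea (the SVC hyperparameters, the number of rounds, and the weight-update factor are all assumptions):
sample_weights = np.ones(m)
for i in range(5):
    svm_clf = SVC(kernel="rbf", C=0.05, gamma="scale", random_state=42)
    svm_clf.fit(X_train, y_train, sample_weight=sample_weights)
    y_pred = svm_clf.predict(X_train)
    sample_weights[y_pred != y_train] *= 2   # boost the weight of misclassified instances (assumed factor)
    plot_decision_boundary(svm_clf, X, y)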
The fitted AdaBoost classifier exposes these learned attributes:
['base_estimator_',
'classes_',
'estimator_errors_',
'estimator_weights_',
'estimators_',
'feature_importances_',
'n_classes_']
Gradient Boosting
np.random.seed(42)
X = np.random.rand(100, 1) - 0.5
y = 3 * X[:, 0]**2 + 0.05 * np.random.randn(100)

from sklearn.tree import DecisionTreeRegressor

tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg1.fit(X, y)

y2 = y - tree_reg1.predict(X)        # residuals of the first tree
tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg2.fit(X, y2)

y3 = y2 - tree_reg2.predict(X)       # residuals of the second tree
tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg3.fit(X, y3)
X_new = np.array([[0.8]])   # one new instance (the exact value is an assumption)
y_pred = sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))
y_pred
array([0.75026781])
plt.figure(figsize=(11, 11))

plt.subplot(321)
plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h_1(x_1)$", style="g-",
                 data_label="Training set")
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.title("Residuals and tree predictions", fontsize=16)

plt.subplot(322)
plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h(x_1) = h_1(x_1)$",
                 data_label="Training set")
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.title("Ensemble predictions", fontsize=16)

plt.subplot(323)
plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.5, 0.5], label="$h_2(x_1)$", style="g-",
                 data_style="k+", data_label="Residuals")
plt.ylabel("$y - h_1(x_1)$", fontsize=16)

plt.subplot(324)
plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.1, 0.8],
                 label="$h(x_1) = h_1(x_1) + h_2(x_1)$")
plt.ylabel("$y$", fontsize=16, rotation=0)

plt.subplot(325)
plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.5, 0.5], label="$h_3(x_1)$", style="g-",
                 data_style="k+")
plt.ylabel("$y - h_1(x_1) - h_2(x_1)$", fontsize=16)
plt.xlabel("$x_1$", fontsize=16)

plt.subplot(326)
plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y, axes=[-0.5, 0.5, -0.1, 0.8],
                 label="$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$")
plt.xlabel("$x_1$", fontsize=16)
plt.ylabel("$y$", fontsize=16, rotation=0)

save_fig("gradient_boosting_plot")
plt.show()
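The comparison below refers to gbrt, gbrt_slow, and an axes array, none of which are defined in this excerpt. A minimal sketch using GradientBoostingRegressor (the hyperparameter values are assumptions chosen to match the fast/slow contrast in the plot titles):
from sklearn.ensemble import GradientBoostingRegressor

gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3, learning_rate=1.0, random_state=42)
gbrt.fit(X, y)

gbrt_slow = GradientBoostingRegressor(max_depth=2, n_estimators=200, learning_rate=0.1, random_state=42)
gbrt_slow.fit(X, y)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)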
plt.sca(axes[0])
plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="Ensemble predictions")
plt.title("learning_rate={}, n_estimators={}".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)
plt.xlabel("$x_1$", fontsize=16)
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.sca(axes[1])
plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])
plt.title("learning_rate={}, n_estimators={}".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators),
fontsize=14)
plt.xlabel("$x_1$", fontsize=16)
save_fig("gbrt_learning_rate_plot")
plt.show()
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
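The plotting code below refers to errors, bst_n_estimators, and gbrt_best, which are not defined in this excerpt. A sketch of how they can be obtained with staged_predict (the validation split and the hyperparameters are assumptions):
from sklearn.ensemble import GradientBoostingRegressor

X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=49)

gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=120, random_state=42)
gbrt.fit(X_train, y_train)

# validation error after each additional tree
errors = [mean_squared_error(y_val, y_pred)
          for y_pred in gbrt.staged_predict(X_val)]
bst_n_estimators = np.argmin(errors) + 1

gbrt_best = GradientBoostingRegressor(max_depth=2, n_estimators=bst_n_estimators, random_state=42)
gbrt_best.fit(X_train, y_train)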
min_error = np.min(errors)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(errors, "b.-")
plt.plot([bst_n_estimators, bst_n_estimators], [0, min_error], "k--")
plt.plot([0, 120], [min_error, min_error], "k--")
plt.plot(bst_n_estimators, min_error, "ko")
plt.text(bst_n_estimators, min_error*1.2, "Minimum", ha="center", fontsize=14)
plt.axis([0, 120, 0, 0.01])
plt.xlabel("Number of trees")
plt.ylabel("Error", fontsize=16)
plt.title("Validation error", fontsize=14)
plt.subplot(122)
plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])
plt.title("Best model (%d trees)" % bst_n_estimators, fontsize=14)
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.xlabel("$x_1$", fontsize=16)
save_fig("early_stopping_gbrt_plot")
plt.show()
min_val_error = float("inf")
error_going_up = 0
for n_estimators in range(1, 120):
gbrt.n_estimators = n_estimators
gbrt.fit(X_train, y_train)
y_pred = gbrt.predict(X_val)
val_error = mean_squared_error(y_val, y_pred)
if val_error < min_val_error:
min_val_error = val_error
error_going_up = 0
else:
error_going_up += 1
if error_going_up == 5:
break # early stopping
print(gbrt.n_estimators)
61
Using XGBoost
try:
    import xgboost
except ImportError as ex:
    print("Error: the xgboost library is not installed.")
    xgboost = None
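The timing output below presumably comes from fitting an XGBoost model on the same regression data; a minimal usage sketch (assuming the import above succeeded and the same training/validation split as above):
if xgboost is not None:
    xgb_reg = xgboost.XGBRegressor(random_state=42)
    xgb_reg.fit(X_train, y_train)
    y_pred = xgb_reg.predict(X_val)
    print("Validation MSE:", mean_squared_error(y_val, y_pred))

# The timing below presumably comes from something like:
# %timeit xgboost.XGBRegressor().fit(X_train, y_train)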
12.9 ms ± 827 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Exercise solutions
1. to 7.
See Appendix A.
8. Voting Classifier
Exercise: Load the MNIST data and split it into a training set, a validation set, and a test set (e.g., use 50,000 instances for training, 10,000
for validation, and 10,000 for testing).
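A minimal sketch of such a split (loading MNIST via fetch_openml and the random seeds are assumptions; the split sizes follow the exercise):
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

mnist = fetch_openml('mnist_784', version=1)

X_train_val, X_test, y_train_val, y_test = train_test_split(
    mnist.data, mnist.target, test_size=10000, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=10000, random_state=42)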
Exercise: Then train various classifiers, such as a Random Forest classifier, an Extra-Trees classifier, and an SVM.
The linear SVM is far outperformed by the other classifiers. However, let's keep it for now since it may improve the voting classifier's
performance.
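The estimator names used below (random_forest_clf, extra_trees_clf, svm_clf, mlp_clf) are not defined in this excerpt; a sketch of how they might be constructed and trained (the hyperparameters are assumptions, and LinearSVC is assumed for the linear SVM mentioned above):
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.svm import LinearSVC
from sklearn.neural_network import MLPClassifier

random_forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
extra_trees_clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
svm_clf = LinearSVC(random_state=42)
mlp_clf = MLPClassifier(random_state=42)

estimators = [random_forest_clf, extra_trees_clf, svm_clf, mlp_clf]
for estimator in estimators:
    estimator.fit(X_train, y_train)
    print(estimator.__class__.__name__, estimator.score(X_val, y_val))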
Exercise: Next, try to combine them into an ensemble that outperforms them all on the validation set, using a soft or hard voting classifier.
named_estimators = [
    ("random_forest_clf", random_forest_clf),
    ("extra_trees_clf", extra_trees_clf),
    ("svm_clf", svm_clf),
    ("mlp_clf", mlp_clf),
]
voting_clf = VotingClassifier(named_estimators)
voting_clf.fit(X_train, y_train)
voting_clf.score(X_val, y_val)
0.9704
Let's remove the SVM to see if performance improves. It is possible to remove an estimator by setting it to None using set_params() like
this:
voting_clf.set_params(svm_clf=None)
voting_clf.estimators
[('random_forest_clf',
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
max_depth=None, max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,
oob_score=False, random_state=42, verbose=0, warm_start=False)),
('extra_trees_clf',
ExtraTreesClassifier(bootstrap=False, class_weight=None, criterion='gini',
max_depth=None, max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,
oob_score=False, random_state=42, verbose=0, warm_start=False)),
('svm_clf', None),
('mlp_clf',
MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
beta_2=0.999, early_stopping=False, epsilon=1e-08,
hidden_layer_sizes=(100,), learning_rate='constant',
learning_rate_init=0.001, max_iter=200, momentum=0.9,
n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
random_state=42, shuffle=True, solver='adam', tol=0.0001,
validation_fraction=0.1, verbose=False, warm_start=False))]
voting_clf.estimators_
This updated the list of estimators, but the list of trained estimators (voting_clf.estimators_) still contains the fitted SVM. So we can either fit the VotingClassifier again, or just remove the SVM from the list of trained estimators:
del voting_clf.estimators_[2]
voting_clf.score(X_val, y_val)
0.9732
A bit better! The SVM was hurting performance. Now let's try using a soft voting classifier. We do not actually need to retrain the classifier; we can just set voting to "soft":
voting_clf.voting = "soft"
voting_clf.score(X_val, y_val)
0.9672
Once you have found one, try it on the test set. How much better does it perform compared to the individual classifiers?
voting_clf.voting = "hard"
voting_clf.score(X_test, y_test)
0.9704
The voting classifier only very slightly reduced the error rate of the best model in this case.
9. Stacking Ensemble
Exercise: Run the individual classifiers from the previous exercise to make predictions on the validation set, and create a new training set with
the resulting predictions: each training instance is a vector containing the set of predictions from all your classifiers for an image, and the
target is the image's class. Train a classifier on this new training set.
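A minimal sketch of this step, matching the names used below (the blender's hyperparameters are assumptions):
X_val_predictions = np.empty((len(X_val), len(estimators)), dtype=np.float32)
for index, estimator in enumerate(estimators):
    X_val_predictions[:, index] = estimator.predict(X_val)

rnd_forest_blender = RandomForestClassifier(n_estimators=200, oob_score=True, random_state=42)
rnd_forest_blender.fit(X_val_predictions, y_val)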
X_val_predictions
rnd_forest_blender.oob_score_
0.9696
You could fine-tune this blender or try other types of blenders (e.g., an MLPClassifier ), then select the best one using cross-validation, as
always.
Exercise: Congratulations, you have just trained a blender, and together with the classifiers they form a stacking ensemble! Now let's evaluate
the ensemble on the test set. For each image in the test set, make predictions with all your classifiers, then feed the predictions to the blender
to get the ensemble's predictions. How does it compare to the voting classifier you trained earlier?
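As with the validation set, we first build a matrix of the individual classifiers' predictions on the test set (a sketch):
X_test_predictions = np.empty((len(X_test), len(estimators)), dtype=np.float32)
for index, estimator in enumerate(estimators):
    X_test_predictions[:, index] = estimator.predict(X_test)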
y_pred = rnd_forest_blender.predict(X_test_predictions)
accuracy_score(y_test, y_pred)
0.9669
This stacking ensemble does not perform as well as the voting classifier we trained earlier; it's not quite as good as the best individual classifier.