
Topic homework: Interpretability with SHAP

In this notebook we explore the open source library SHAP for interpreting black-box machine learning
models. SHAP comes with many clear benefits:

Allows us to understand the key drivers of heterogeneity in complex models - such as neural
networks or boosted trees
Can be calculated quickly and expressed visually - no need to fit multiple models

Let's see an example of SHAP plots in action:

pip install econml

pip install shap

## Ignore warnings
import warnings
warnings.filterwarnings("ignore")
from econml.dml import CausalForestDML, LinearDML, NonParamDML
from econml.dr import DRLearner
from econml.metalearners import DomainAdaptationLearner, XLearner
from econml.iv.dr import LinearIntentToTreatDRIV
import numpy as np
import scipy.special
import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import Lasso

import sklearn

np.random.seed(123)
n_samples = 5000
n_features = 10
true_te = lambda X: (X[:, 0]>0) * X[:, 0]
X = np.random.normal(0, 1, size=(n_samples, n_features))
W = np.random.normal(0, 1, size=(n_samples, n_features))
T = np.random.binomial(1, scipy.special.expit(X[:, 0]))
y = true_te(X) * T + 5.0 * X[:, 0] + np.random.normal(0, .1, size=(n_samples,))
X_test = X[:min(100, n_samples)].copy()
X_test[:, 0] = np.linspace(np.percentile(X[:, 0], 1), np.percentile(X[:, 0], 99), min(100, n_samples))

Here we fit a Causal Forest Double Machine Learning estimator - a forest model with residualization - to
synthetic data. The data was generated so that the first feature has a strong causal effect, while the
remaining features are random noise with no effect on the outcome.
est = CausalForestDML(random_state=123)
est.fit(y, T, X=X, W=W)
shap_values = est.shap_values(X[:20])
shap.plots.beeswarm(shap_values['Y0']['T0'])

It's important to note that a Shapley value is calculated for each of the 20 rows passed to the
est.shap_values() function. The beeswarm plot shows those 20 points per feature, with a random vertical
jitter to avoid overlapping points. As a result, there is not a single Shapley value per feature, but a
Shapley value per feature per observation.
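
As a small sketch of this structure (reusing the shap_values dictionary returned above), the Explanation
object stores one value per observation and per feature:

# Sketch: inspect the per-observation, per-feature structure of the SHAP values
# computed by est.shap_values(X[:20]) above.
exp = shap_values['Y0']['T0']
print(exp.values.shape)  # (20, n_features): one Shapley value per row and per feature
print(exp.values[0])     # Shapley values for the first observation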

The SHAP plot clearly indicates that high values of the first feature have a significant impact on the
model output. But what does this impact mean?

We consult the documentation of SHAP plots to better understand what SHAP values represent:

An easier way to build intuition for SHAP is to compare it to a linear model and its coefficients:

# a classic housing price dataset
X, y = shap.datasets.boston()
X100 = shap.utils.sample(X, 100)  # 100 instances for use as the background distribution

# a simple linear model
model = sklearn.linear_model.LinearRegression()
model.fit(X, y)

print("Model coefficients:\n")
for i in range(X.shape[1]):
    print(X.columns[i], "=", model.coef_[i].round(4))

Model coefficients:

CRIM = -0.108
ZN = 0.0464
INDUS = 0.0206
CHAS = 2.6867
NOX = -17.7666
RM = 3.8099
AGE = 0.0007
DIS = -1.4756
RAD = 0.306
TAX = -0.0123
PTRATIO = -0.9527
B = 0.0093
LSTAT = -0.5248

Here we see the familiar linear regression coefficients. However, the value of a coefficient depends on
the scale of its feature, so its absolute value is not indicative of the feature's importance.
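
As a quick, hypothetical illustration of this scale dependence (not part of the original analysis), we can
rescale one feature and refit: the coefficient changes by the inverse factor even though the fitted model
and its predictions are unchanged.

# Hypothetical check: measure AGE in centuries instead of years and refit.
# The AGE coefficient grows by a factor of 100, but the predictions do not change.
X_rescaled = X.copy()
X_rescaled["AGE"] = X_rescaled["AGE"] / 100.0
model_rescaled = sklearn.linear_model.LinearRegression().fit(X_rescaled, y)
age_idx = list(X.columns).index("AGE")
print("AGE coefficient (years):    ", model.coef_[age_idx].round(4))
print("AGE coefficient (centuries):", model_rescaled.coef_[age_idx].round(4))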
Instead, the authors of SHAP suggest a partial dependence plot; here we see it plotted for the RM
feature.

shap.plots.partial_dependence(
    "RM", model.predict, X100, ice=False,
    model_expected_value=True, feature_expected_value=True
)

We see that, because this model is linear, the expected value of the model's predictions as RM increases
(with all the other features marginalized out) is the straight blue line.
To calculate SHAP values, we evaluate the model 𝑓 on subsets 𝑆 of the features, integrating out the
remaining features using a conditional expectation formulation. As a result, we see how the predicted
function changes as the feature of interest changes.
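
For reference, the quantity being computed is the classical Shapley value used in the SHAP papers, where
$f_x(S) = \mathbb{E}[f(X) \mid X_S = x_S]$ is the model with the features outside $S$ integrated out and
$N$ is the set of all features:

$$\phi_i(x) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N|-|S|-1)!}{|N|!} \Big[ f_x(S \cup \{i\}) - f_x(S) \Big]$$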

explainer = shap.Explainer(model.predict, X100)
shap_values = explainer(X)

# make a standard partial dependence plot
sample_ind = 18
shap.partial_dependence_plot(
    "RM", model.predict, X100, model_expected_value=True,
    feature_expected_value=True, ice=False,
    shap_values=shap_values[sample_ind:sample_ind+1, :]
)

Permutation explainer: 507it [00:24, 15.98it/s]

From here, at a given observation 𝑥𝑖 (recall that Shapley values are calculated for each observation), the
deviation of the model's prediction from the model's mean prediction (the red line above) is approximately
-3.09, which is the Shapley value for this observation and feature.
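
As a quick sanity check (a sketch reusing the objects above; the additivity is approximate for the
permutation explainer), the base value plus an observation's Shapley values should recover the model's
prediction:

# Sketch: SHAP's additivity property for the chosen observation.
base = shap_values[sample_ind].base_values      # mean prediction over the background data
contrib = shap_values[sample_ind].values.sum()  # sum of the per-feature Shapley values
print("base value + sum of SHAP values:", base + contrib)
print("model prediction:               ", model.predict(X.iloc[[sample_ind]])[0])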

shap.plots.waterfall(shap_values[sample_ind], max_display=14)

In this notebook we took a deeper dive into Shapley plots and learned that:

Shapley plots can show the impact of a feature, and are not affected by feature scale the way linear
coefficients are
Shapley values are calculated per observation and per feature; they marginalize out the other
features and measure the change in prediction, at the observed value of the given feature, relative to
the mean prediction
Shapley plots are a fast and visual way of making complex models more interpretable.
