# Source code for tic.interpret.shap
from typing import Dict, List, Optional
import shap
import pandas as pd
from sklearn.base import BaseEstimator
def _create_explainer(
    clf: BaseEstimator,
    X_train: pd.DataFrame,
    **kwargs
):
    '''
    Build a SHAP kernel explainer around a classifier's probability output.

    The kernel explainer is chosen because it offers the broadest coverage
    of model types among the SHAP explainers.

    Args:
        clf: classifier exposing a ``predict_proba`` method
        X_train: data handed to the explainer as its ``data`` argument
        **kwargs: forwarded unchanged to ``shap.KernelExplainer``

    Returns:
        The constructed ``shap.KernelExplainer`` instance
    '''
    # The explainer is driven by class probabilities, not hard labels.
    predict_fn = clf.predict_proba
    kernel_explainer = shap.KernelExplainer(
        model=predict_fn,
        data=X_train,
        **kwargs
    )
    return kernel_explainer
def explain_local(
    clf: BaseEstimator,
    X_train: pd.DataFrame,
    instance: pd.Series,
    class_names: List,
    sample_size: Optional[int] = 100,
    explainer_kwargs: Optional[Dict] = None,
    explanation_kwargs: Optional[Dict] = None
):
    '''
    Creates an explainer and explains the given instance using SHAP.

    Args:
        clf : Fitted classifier from sklearn
        X_train: data that was used to train the classifier
        instance: instance to explain
        class_names: names of class labels
        sample_size: value forwarded as ``nsamples`` to the explainer's
            ``shap_values`` call. ``None`` selects SHAP's ``"auto"``
            setting.
        explainer_kwargs: Keyword args passed during explainer
            initialization. ``None`` (the default) means no extra
            arguments.
        explanation_kwargs: Keyword args passed for explanation (forwarded
            to ``shap.force_plot``). ``None`` (the default) means no extra
            arguments.

    Returns:
        Dict with the keys ``'explainer'``, ``'shap_values'`` and
        ``'figure'`` (enriched SHAP explanation including figure)
    '''
    # The parameters are declared Optional, so an explicit None must behave
    # like "no extra arguments" rather than crash on ``**None``. Using None
    # as the default also avoids sharing one mutable dict across calls.
    explainer_kwargs = {} if explainer_kwargs is None else explainer_kwargs
    explanation_kwargs = (
        {} if explanation_kwargs is None else explanation_kwargs
    )
    # SHAP accepts either an int or the string "auto" for nsamples; None is
    # not a valid value, so translate it to the library's automatic choice.
    nsamples = 'auto' if sample_size is None else sample_size

    explainer = _create_explainer(
        clf=clf,
        X_train=X_train,
        **explainer_kwargs
    )
    shap_values = explainer.shap_values(instance, nsamples=nsamples)

    # Only the first model output (index 0) is plotted; the complete
    # shap_values are still returned to the caller.
    # NOTE(review): class_names is passed whole as out_names although a
    # single output is plotted - confirm this is intended.
    figure = shap.force_plot(
        base_value=explainer.expected_value[0],
        shap_values=shap_values[0],
        features=instance,
        out_names=class_names,
        matplotlib=True,
        show=False,
        **explanation_kwargs
    )
    return {
        'explainer': explainer,
        'shap_values': shap_values,
        'figure': figure
    }
def explain_global(
    clf: BaseEstimator,
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    class_names: List,
    sample_size: Optional[int] = 100,
    explainer_kwargs: Optional[Dict] = None,
    explanation_kwargs: Optional[Dict] = None
):
    '''
    Creates an explainer and explanations for a given dataset using SHAP.

    Args:
        clf : Fitted classifier from sklearn
        X_train: data that was used to train the classifier
        X_test: data that should be explained
        class_names: names of class labels
        sample_size: value forwarded as ``nsamples`` to the explainer's
            ``shap_values`` call. ``None`` selects SHAP's ``"auto"``
            setting.
        explainer_kwargs: Keyword args passed during explainer
            initialization. ``None`` (the default) means no extra
            arguments.
        explanation_kwargs: Keyword args passed for explanation (forwarded
            to ``shap.force_plot``). ``None`` (the default) means no extra
            arguments.

    Returns:
        Dict with the keys ``'explainer'``, ``'shap_values'`` and
        ``'figure'`` (enriched SHAP explanation including interactive
        figure)
    '''
    # The parameters are declared Optional, so an explicit None must behave
    # like "no extra arguments" rather than crash on ``**None``. Using None
    # as the default also avoids sharing one mutable dict across calls.
    explainer_kwargs = {} if explainer_kwargs is None else explainer_kwargs
    explanation_kwargs = (
        {} if explanation_kwargs is None else explanation_kwargs
    )
    # SHAP accepts either an int or the string "auto" for nsamples; None is
    # not a valid value, so translate it to the library's automatic choice.
    nsamples = 'auto' if sample_size is None else sample_size

    explainer = _create_explainer(
        clf=clf,
        X_train=X_train,
        **explainer_kwargs
    )
    shap_values = explainer.shap_values(X_test, nsamples=nsamples)

    # Only the first model output (index 0) is plotted; the complete
    # shap_values are still returned to the caller.
    # NOTE(review): class_names is passed whole as out_names although a
    # single output is plotted - confirm this is intended.
    figure = shap.force_plot(
        base_value=explainer.expected_value[0],
        shap_values=shap_values[0],
        features=X_test,
        out_names=class_names,
        show=False,
        **explanation_kwargs
    )
    return {
        'explainer': explainer,
        'shap_values': shap_values,
        'figure': figure
    }