Source code for physlearn.supervised.interpretation.interpret_regressor

"""
The :mod:`physlearn.supervised.interpretation.interpret_regressor`
module provides SHAP utilities for regressor interpretability.
It includes the :class:`physlearn.ShapInterpret`
class.
"""

# Author: Alex Wozniakowski
# License: MIT

import shap

import sklearn.utils.multiclass
import sklearn.utils.validation

import matplotlib.pyplot as plt

from dataclasses import dataclass, field
from IPython.display import display

from physlearn.supervised.regression import BaseRegressor
from physlearn.supervised.utils._data_checks import _n_targets
from physlearn.supervised.utils._definition import (_MULTI_TARGET, _SHAP_TAXONOMY,
                                                    _SHAP_SUMMARY_PLOT_CHOICE)


@dataclass
class ShapInterpret(BaseRegressor):
    """Interpret a regressor's output with SHAP plots.

    Each plotting method refits the regressor once per target column and
    then draws the corresponding SHAP plot for that target.

    Parameters
    ----------
    show : bool, default=True
        Forwarded as the ``show`` argument of the SHAP summary, dependence,
        and decision plots.
    """

    show: bool = field(default=True)

    def __post_init__(self):
        """Select the SHAP explainer family and build the regressor."""
        # The taxonomy maps the chosen regressor onto the explainer family
        # ('tree', 'linear' or 'kernel') that ``explainer`` dispatches on.
        self.explainer_type = _SHAP_TAXONOMY[self.regressor_choice]
        self._validate_display_options()
        # NOTE(review): the two calls below are not defined in this class;
        # presumably they are inherited from BaseRegressor.
        self._validate_regressor_options()
        self._get_regressor()

    def _validate_display_options(self):
        """Check that ``show`` is a boolean."""
        assert isinstance(self.show, bool), 'show must be a bool.'

    def fit(self, X, y, index=None, sample_weight=None):
        """Fit regressor.

        Parameters
        ----------
        X : array-like
            Training data, passed unchanged to the underlying fit.
        y : array-like
            Target values. When ``index`` is given and ``y`` is one of the
            multi-target types, ``y`` must support ``.iloc`` column slicing.
        index : int or None, default=None
            Column of ``y`` to fit on. Ignored when ``y`` is not one of the
            multi-target types.
        sample_weight : array-like or None, default=None
            Passed unchanged to the underlying fit.
        """
        is_multi_target = (sklearn.utils.multiclass.type_of_target(y) in _MULTI_TARGET
                           if index is not None else False)

        if is_multi_target:
            # Slice out the single requested target column once and reuse it.
            target = y.iloc[:, index]
            super().get_pipeline(y=target)
            super()._fit(regressor=self.pipe, X=X,
                         y=target.values.ravel(order='K'),
                         sample_weight=sample_weight)
        else:
            super().get_pipeline(y=y)
            super()._fit(regressor=self.pipe, X=X, y=y,
                         sample_weight=sample_weight)

    def explainer(self, X):
        """Compute the importance of each feature for the underlying regressor.

        Parameters
        ----------
        X : array-like
            Data used both as the explainer's background data and as the
            samples whose SHAP values are computed.

        Returns
        -------
        explainer : object
            The SHAP explainer built around the pipeline's ``'reg'`` step.
        shap_values : object
            The result of ``explainer.shap_values`` evaluated on ``X``.

        Raises
        ------
        AttributeError
            If the pipeline has not been built, i.e. ``fit`` was not called.
        ValueError
            If ``self.explainer_type`` is not 'tree', 'linear' or 'kernel'.
        """
        try:
            sklearn.utils.validation.check_is_fitted(estimator=self.pipe,
                                                     attributes='_final_estimator')
        except AttributeError as err:
            # Fail here with the actionable message instead of printing it
            # and then failing below on the missing pipeline.
            raise AttributeError('The pipeline has not been built. '
                                 'Please use the fit method beforehand.') from err

        if self.explainer_type == 'tree':
            explainer = shap.TreeExplainer(model=self.pipe.named_steps['reg'],
                                           feature_perturbation='interventional',
                                           data=X)
            shap_values = explainer.shap_values(X=X)
        elif self.explainer_type == 'linear':
            explainer = shap.LinearExplainer(model=self.pipe.named_steps['reg'],
                                             feature_perturbation='correlation_dependent',
                                             masker=X)
            shap_values = explainer.shap_values(X=X)
        elif self.explainer_type == 'kernel':
            # The kernel explainer is given the predict callable rather
            # than the estimator object.
            explainer = shap.KernelExplainer(model=self.pipe.named_steps['reg'].predict,
                                             data=X)
            shap_values = explainer.shap_values(X=X, l1_reg='aic')
        else:
            # Without this branch the return below would raise an opaque
            # UnboundLocalError.
            raise ValueError('Unknown explainer type: %r. Expected one of '
                             "'tree', 'linear' or 'kernel'."
                             % (self.explainer_type,))

        return explainer, shap_values

    def summary_plot(self, X, y, plot_type='dot'):
        """Visualization of the feature importance and feature effects.

        Parameters
        ----------
        X : DataFrame-like
            Feature data; its ``columns`` supply the feature names.
        y : array-like
            Target values; one plot is drawn per target.
        plot_type : str, default='dot'
            Must be a member of ``_SHAP_SUMMARY_PLOT_CHOICE``.
        """
        assert plot_type in _SHAP_SUMMARY_PLOT_CHOICE, (
            'plot_type must be one of %s.' % (_SHAP_SUMMARY_PLOT_CHOICE,))

        # Automates single-target slicing
        y = super()._check_target_index(y=y)

        for index in range(_n_targets(y)):
            self.fit(X=X, y=y, index=index)
            _, shap_values = self.explainer(X=X)
            shap.summary_plot(shap_values=shap_values, features=X,
                              plot_type=plot_type,
                              feature_names=list(X.columns),
                              show=self.show)

    def force_plot(self, X, y):
        """Interactive Javascript visualization of Shapley values.

        Parameters
        ----------
        X : DataFrame-like
            Feature data; its ``columns`` supply the feature names.
        y : array-like
            Target values; one plot is displayed per target.
        """
        shap.initjs()

        # Automates single-target slicing
        y = super()._check_target_index(y=y)

        for index in range(_n_targets(y)):
            self.fit(X=X, y=y, index=index)
            explainer, shap_values = self.explainer(X=X)
            # The plot is rendered through IPython's display; the return
            # value is not needed.
            display(shap.force_plot(base_value=explainer.expected_value,
                                    shap_values=shap_values,
                                    features=X,
                                    plot_cmap=['#52A267', '#F0693B'],
                                    feature_names=list(X.columns)))

    def dependence_plot(self, X, y, interaction_index='auto', alpha=None, dot_size=None):
        """Visualization of a feature's effect on a regressor's prediction.

        Parameters
        ----------
        X : DataFrame-like
            Feature data; its ``columns`` supply the feature names.
        y : array-like
            Target values; one plot is drawn per target.
        interaction_index : default='auto'
            Forwarded to ``shap.dependence_plot``.
        alpha : default=None
            Forwarded to ``shap.dependence_plot``.
        dot_size : default=None
            Forwarded to ``shap.dependence_plot``.
        """
        # Automates single-target slicing
        y = super()._check_target_index(y=y)

        for index in range(_n_targets(y)):
            self.fit(X=X, y=y, index=index)
            _, shap_values = self.explainer(X=X)
            # The plotted feature is always selected with 'rank(0)'.
            shap.dependence_plot(ind='rank(0)', shap_values=shap_values,
                                 features=X, feature_names=list(X.columns),
                                 cmap=plt.get_cmap('hot'),
                                 interaction_index=interaction_index,
                                 alpha=alpha, dot_size=dot_size,
                                 show=self.show)

    def decision_plot(self, X, y):
        """Visualization of the additive feature attribution.

        Parameters
        ----------
        X : DataFrame-like
            Feature data; its ``columns`` supply the feature names.
        y : array-like
            Target values; one plot is drawn per target.
        """
        # Automates single-target slicing
        y = super()._check_target_index(y=y)

        for index in range(_n_targets(y)):
            self.fit(X=X, y=y, index=index)
            explainer, shap_values = self.explainer(X=X)
            shap.decision_plot(base_value=explainer.expected_value,
                               shap_values=shap_values,
                               feature_names=list(X.columns),
                               show=self.show)