# src/sensai/feature_importance.py
import collections
import copy
import logging
import re
from abc import ABC, abstractmethod
from typing import Dict, Union, Sequence, List, Tuple, Optional

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from .data import InputOutputData
from .evaluation.crossval import VectorModelCrossValidationData
from .util.deprecation import deprecated
from .util.plot import MATPLOTLIB_DEFAULT_FIGURE_SIZE
from .util.string import ToStringMixin
from .vector_model import VectorModel

log = logging.getLogger(__name__)


class FeatureImportance:
    def __init__(self, feature_importance_dict: Union[Dict[str, float], Dict[str, Dict[str, float]]]):
        self.feature_importance_dict = feature_importance_dict
        self._isMultiVar = self._is_dict(next(iter(feature_importance_dict.values())))

    @staticmethod
    def _is_dict(x):
        return hasattr(x, "get")

    def get_feature_importance_dict(self, predicted_var_name=None) -> Dict[str, float]:
        if self._isMultiVar:
            self.feature_importance_dict: Dict[str, Dict[str, float]]
            if predicted_var_name is not None:
                return self.feature_importance_dict[predicted_var_name]
            else:
                if len(self.feature_importance_dict) > 1:
                    raise ValueError("Must provide predicted variable name (multiple output variables)")
                else:
                    return next(iter(self.feature_importance_dict.values()))
        else:
            return self.feature_importance_dict

    def get_sorted_tuples(self, predicted_var_name=None, reverse=False) -> List[Tuple[str, float]]:
        """
        :param predicted_var_name: the predicted variable name for which to retrieve the sorted feature importance values
        :param reverse: whether to reverse the order (i.e. descending order of importance values, where the most important feature comes
            first, rather than ascending order)
        :return: a sorted list of tuples (feature name, feature importance)
        """
        # noinspection PyTypeChecker
        tuples: List[Tuple[str, float]] = list(self.get_feature_importance_dict(predicted_var_name).items())
        tuples.sort(key=lambda t: t[1], reverse=reverse)
        return tuples

    def plot(self, predicted_var_name=None, sort=True) -> plt.Figure:
        return plot_feature_importance(self.get_feature_importance_dict(predicted_var_name=predicted_var_name), sort=sort)

    def get_data_frame(self, predicted_var_name=None) -> pd.DataFrame:
        """
        :param predicted_var_name: the predicted variable name
        :return: a data frame with two columns, "feature" and "importance"
        """
        names_and_importance = self.get_sorted_tuples(predicted_var_name=predicted_var_name, reverse=True)
        return pd.DataFrame(names_and_importance, columns=["feature", "importance"])
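
# Usage sketch (illustrative, not part of the module): accessing the importance
# values of one output variable from a multi-output importance dictionary; the
# variable and feature names here are made up.
#
#     fi = FeatureImportance({"y1": {"a": 0.7, "b": 0.3}, "y2": {"a": 0.1, "b": 0.9}})
#     fi.get_feature_importance_dict("y1")      # {'a': 0.7, 'b': 0.3}
#     fi.get_sorted_tuples("y1", reverse=True)  # [('a', 0.7), ('b', 0.3)]
#     fi.get_data_frame("y1")                   # data frame with columns "feature", "importance"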


class FeatureImportanceProvider(ABC):
    """
    Interface for models that can provide feature importance values
    """
    @abstractmethod
    def get_feature_importance_dict(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        """
        Gets the feature importance values

        :return: either a dictionary mapping feature names to importance values or, for models which (independently)
            predict multiple variables, a dictionary mapping each predicted variable name to such a dictionary
        """
        pass

    def get_feature_importance(self) -> FeatureImportance:
        return FeatureImportance(self.get_feature_importance_dict())

    @deprecated("Use get_feature_importance_dict or the high-level interface get_feature_importance instead.")
    def get_feature_importances(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        return self.get_feature_importance_dict()
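
# Minimal illustrative implementation of the interface (hypothetical; in practice,
# providers are model classes, e.g. wrappers around sklearn tree ensembles):
#
#     class FixedImportanceProvider(FeatureImportanceProvider):
#         def __init__(self, d: Dict[str, float]):
#             self._d = d
#
#         def get_feature_importance_dict(self) -> Dict[str, float]:
#             return self._d
#
#     fi = FixedImportanceProvider({"a": 0.5, "b": 0.5}).get_feature_importance()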


def plot_feature_importance(feature_importance_dict: Dict[str, float], subtitle: str = None, sort=True) -> plt.Figure:
    if sort:
        feature_importance_dict = {k: v for k, v in sorted(feature_importance_dict.items(), key=lambda x: x[1], reverse=True)}
    num_features = len(feature_importance_dict)
    default_width, default_height = MATPLOTLIB_DEFAULT_FIGURE_SIZE
    height = max(default_height, default_height * num_features / 20)
    fig, ax = plt.subplots(figsize=(default_width, height))
    sns.barplot(x=list(feature_importance_dict.values()), y=list(feature_importance_dict.keys()), ax=ax)
    title = "Feature Importance"
    if subtitle is not None:
        title += "\n" + subtitle
    plt.title(title)
    plt.tight_layout()
    return fig
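
# Usage sketch (made-up importance values): plotting a dictionary directly,
# producing a horizontal bar chart sorted by decreasing importance.
#
#     fig = plot_feature_importance({"age": 0.5, "income": 0.3, "zip": 0.2}, subtitle="example")
#     fig.savefig("feature_importance.png")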


class AggregatedFeatureImportance:
    """
    Aggregates feature importance values (e.g. from models implementing FeatureImportanceProvider, such as sklearn's random forest
    models and compatible models from lightgbm, etc.)
    """
    def __init__(self, *items: Union[FeatureImportanceProvider, Dict[str, float], Dict[str, Dict[str, float]]],
            feature_agg_reg_ex: Sequence[str] = (), agg_fn=np.mean):
        r"""
        :param items: (optional) initial list of feature importance providers or dictionaries to aggregate; further
            values can be added via method add
        :param feature_agg_reg_ex: a sequence of regular expressions describing which feature names to aggregate as one. Each regex must
            contain exactly one group. If a regex matches a feature name, the importance values will be collected under the key of the
            matched group instead of the full feature name (and subsequently combined via agg_fn). For example, the regex r"(\w+)_\d+$"
            will cause "foo_1" and "foo_2" to be aggregated under "foo" and similarly "bar_1" and "bar_2" to be aggregated under "bar".
        :param agg_fn: the function with which to aggregate the importance values collected for each key (defaults to the mean)
        """
        self._agg_dict = None
        self._is_nested = None
        self._num_dicts_added = 0
        self._feature_agg_reg_ex = [re.compile(p) for p in feature_agg_reg_ex]
        self._agg_fn = agg_fn
        for item in items:
            self.add(item)

    @staticmethod
    def _is_dict(x):
        return hasattr(x, "get")

    def add(self, feature_importance: Union[FeatureImportanceProvider, Dict[str, float], Dict[str, Dict[str, float]]]):
        """
        Adds the feature importance values from the given provider or dictionary

        :param feature_importance: a FeatureImportanceProvider or a dictionary obtained via a model's
            get_feature_importance_dict method
        """
        if isinstance(feature_importance, FeatureImportanceProvider):
            feature_importance = feature_importance.get_feature_importance_dict()
        if self._is_nested is None:
            self._is_nested = self._is_dict(next(iter(feature_importance.values())))
        if self._is_nested:
            if self._agg_dict is None:
                self._agg_dict = collections.defaultdict(lambda: collections.defaultdict(list))
            for target_name, d in feature_importance.items():
                d: dict
                for feature_name, value in d.items():
                    self._agg_dict[target_name][self._agg_feature_name(feature_name)].append(value)
        else:
            if self._agg_dict is None:
                self._agg_dict = collections.defaultdict(list)
            for feature_name, value in feature_importance.items():
                self._agg_dict[self._agg_feature_name(feature_name)].append(value)
        self._num_dicts_added += 1

    def _agg_feature_name(self, feature_name: str):
        for regex in self._feature_agg_reg_ex:
            m = regex.match(feature_name)
            if m is not None:
                return m.group(1)
        return feature_name

    def get_aggregated_feature_importance_dict(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        def aggregate(d: dict):
            return {k: self._agg_fn(l) for k, l in d.items()}

        if self._is_nested:
            return {k: aggregate(d) for k, d in self._agg_dict.items()}
        else:
            return aggregate(self._agg_dict)

    def get_aggregated_feature_importance(self) -> FeatureImportance:
        return FeatureImportance(self.get_aggregated_feature_importance_dict())
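
# Aggregation sketch (made-up values): with the regex from the docstring above,
# one-hot style columns "foo_1"/"foo_2" are collected under "foo", and the default
# agg_fn (np.mean) averages all collected values per key:
#
#     agg = AggregatedFeatureImportance(feature_agg_reg_ex=[r"(\w+)_\d+$"])
#     agg.add({"foo_1": 0.2, "foo_2": 0.4, "bar": 0.4})
#     agg.add({"foo_1": 0.3, "foo_2": 0.3, "bar": 0.4})
#     agg.get_aggregated_feature_importance_dict()
#     # {'foo': 0.3, 'bar': 0.4}  (mean of [0.2, 0.4, 0.3, 0.3] and [0.4, 0.4])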


def compute_permutation_feature_importance_dict(model, io_data: InputOutputData, scoring, num_repeats: int, random_state,
        exclude_input_preprocessors=False, num_jobs=None):
    from sklearn.inspection import permutation_importance
    if exclude_input_preprocessors:
        inputs = model.compute_model_inputs(io_data.inputs)
        model = copy.copy(model)
        model.remove_input_preprocessors()
    else:
        inputs = io_data.inputs
    feature_names = inputs.columns
    pi = permutation_importance(model, inputs, io_data.outputs, n_repeats=num_repeats, random_state=random_state, scoring=scoring,
        n_jobs=num_jobs)
    importance_values = pi.importances_mean
    assert len(importance_values) == len(feature_names)
    feature_importance_dict = dict(zip(feature_names, importance_values))
    return feature_importance_dict
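
# Usage sketch (hypothetical names): `model` is assumed to be a fitted VectorModel
# and `io_data` the held-out data on which to score it.
#
#     fi_dict = compute_permutation_feature_importance_dict(model, io_data, scoring="r2",
#         num_repeats=5, random_state=42)
#     plot_feature_importance(fi_dict, subtitle="permutation importance")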


class AggregatedPermutationFeatureImportance(ToStringMixin):
    def __init__(self, aggregated_feature_importance: AggregatedFeatureImportance, scoring, num_repeats=5, random_seed=42,
            exclude_model_input_preprocessors=False, num_jobs: Optional[int] = None):
        """
        :param aggregated_feature_importance: the object in which to aggregate the feature importance (to which no feature importance
            values should have been added yet)
        :param scoring: the scoring method; see https://scikit-learn.org/stable/modules/model_evaluation.html; e.g. "r2" for regression or
            "accuracy" for classification
        :param num_repeats: the number of data permutations to apply for each model
        :param random_seed: the random seed for shuffling the data
        :param exclude_model_input_preprocessors: whether to exclude model input preprocessors, such that the
            feature importance will be reported on the transformed inputs that are actually fed to the model rather than the original
            inputs.
            Enabling this can, for example, help save time in cases where the input preprocessors discard many of the raw input
            columns, but it may not be a good idea if the preprocessors generate multiple columns from the original input columns.
        :param num_jobs:
            Number of jobs to run in parallel. Each separate model-data permutation feature importance computation is parallelised over
            the columns. `None` means 1 unless in a :obj:`joblib.parallel_backend` context.
            `-1` means using all processors.
        """
        self._agg = aggregated_feature_importance
        self.scoring = scoring
        self.numRepeats = num_repeats
        self.randomSeed = random_seed
        self.excludeModelInputPreprocessors = exclude_model_input_preprocessors
        self.numJobs = num_jobs

    def add(self, model: VectorModel, io_data: InputOutputData):
        feature_importance_dict = compute_permutation_feature_importance_dict(model, io_data, self.scoring, num_repeats=self.numRepeats,
            random_state=self.randomSeed, exclude_input_preprocessors=self.excludeModelInputPreprocessors, num_jobs=self.numJobs)
        self._agg.add(feature_importance_dict)

    def add_cross_validation_data(self, cross_val_data: VectorModelCrossValidationData):
        if cross_val_data.trained_models is None:
            raise ValueError("No models in cross-validation data; enable model collection during cross-validation")
        for i, (model, eval_data) in enumerate(zip(cross_val_data.trained_models, cross_val_data.eval_data_list), start=1):
            log.info(f"Computing permutation feature importance for model #{i}/{len(cross_val_data.trained_models)}")
            self.add(model, eval_data.io_data)

    def get_feature_importance(self) -> FeatureImportance:
        return self._agg.get_aggregated_feature_importance()
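
# Usage sketch (hypothetical names): aggregating permutation feature importance
# over the models trained during cross-validation; assumes the cross-validation
# was run with model collection enabled so that trained_models is populated.
#
#     apfi = AggregatedPermutationFeatureImportance(AggregatedFeatureImportance(), scoring="r2")
#     apfi.add_cross_validation_data(cross_val_data)
#     apfi.get_feature_importance().plot()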