Coverage for src/sensai/feature_importance.py: 33%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import collections 

2import copy 

3import logging 

4import re 

5from abc import ABC, abstractmethod 

6from typing import Dict, Union, Sequence, List, Tuple, Optional 

7 

8import numpy as np 

9import pandas as pd 

10import seaborn as sns 

11from matplotlib import pyplot as plt 

12 

13from .data import InputOutputData 

14from .evaluation.crossval import VectorModelCrossValidationData 

15from .util.deprecation import deprecated 

16from .util.plot import MATPLOTLIB_DEFAULT_FIGURE_SIZE 

17from .util.string import ToStringMixin 

18from .vector_model import VectorModel 

19 

20log = logging.getLogger(__name__) 

21 

22 

class FeatureImportance:
    """
    Represents feature importance values, either for a single predicted variable (a dictionary mapping feature names
    to importance values) or for several predicted variables (a dictionary mapping predicted variable names to such
    dictionaries).
    """
    def __init__(self, feature_importance_dict: Union[Dict[str, float], Dict[str, Dict[str, float]]]):
        """
        :param feature_importance_dict: either a dictionary mapping feature names to importance values or a dictionary
            mapping predicted variable names to such dictionaries. The nested variant is detected by inspecting the
            first value; an empty dictionary is treated as the non-nested variant (i.e. no features).
        """
        self.feature_importance_dict = feature_importance_dict
        # pass a default to next() such that an empty dictionary does not raise StopIteration; None is not dict-like,
        # so an empty dictionary is classified as non-nested
        self._isMultiVar = self._is_dict(next(iter(feature_importance_dict.values()), None))

    @staticmethod
    def _is_dict(x) -> bool:
        # duck-typed check for dictionary-like objects
        return hasattr(x, "get")

    def get_feature_importance_dict(self, predicted_var_name=None) -> Dict[str, float]:
        """
        :param predicted_var_name: the predicted variable name for which to retrieve the feature importance values.
            It is required if the values are nested and there is more than one predicted variable. If the values are
            not nested, it is ignored.
        :return: a dictionary mapping feature names to importance values
        :raises ValueError: if the values are nested, there are multiple predicted variables and no name was given
        """
        if not self._isMultiVar:
            return self.feature_importance_dict
        if predicted_var_name is not None:
            return self.feature_importance_dict[predicted_var_name]
        if len(self.feature_importance_dict) > 1:
            raise ValueError("Must provide predicted variable name (multiple output variables)")
        # exactly one predicted variable: return its dictionary
        return next(iter(self.feature_importance_dict.values()))

    def get_sorted_tuples(self, predicted_var_name=None, reverse=False) -> List[Tuple[str, float]]:
        """
        :param predicted_var_name: the predicted variable name for which to retrieve the sorted feature importance values
        :param reverse: whether to reverse the order (i.e. descending order of importance values, where the most important feature comes
            first, rather than ascending order)
        :return: a sorted list of tuples (feature name, feature importance)
        """
        # noinspection PyTypeChecker
        tuples: List[Tuple[str, float]] = list(self.get_feature_importance_dict(predicted_var_name).items())
        tuples.sort(key=lambda t: t[1], reverse=reverse)
        return tuples

    def plot(self, predicted_var_name=None, sort=True) -> "plt.Figure":
        """
        Plots the feature importance values as a bar plot (via plot_feature_importance).

        :param predicted_var_name: the predicted variable name (see get_feature_importance_dict)
        :param sort: whether to sort the features by importance value
        :return: the figure
        """
        return plot_feature_importance(self.get_feature_importance_dict(predicted_var_name=predicted_var_name), sort=sort)

    def get_data_frame(self, predicted_var_name=None) -> pd.DataFrame:
        """
        :param predicted_var_name: the predicted variable name
        :return: a data frame with two columns, "feature" and "importance", with rows in descending order of importance
        """
        names_and_importance = self.get_sorted_tuples(predicted_var_name=predicted_var_name, reverse=True)
        return pd.DataFrame(names_and_importance, columns=["feature", "importance"])

67 

68 

class FeatureImportanceProvider(ABC):
    """
    Interface for models that can provide feature importance values
    """
    @abstractmethod
    def get_feature_importance_dict(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        """
        Gets the feature importance values

        :return: either a dictionary mapping feature names to importance values or (for models predicting multiple
            variables (independently)) a dictionary which maps predicted variable names to such dictionaries
        """
        pass

    def get_feature_importance(self) -> FeatureImportance:
        """
        :return: the values of get_feature_importance_dict wrapped in a FeatureImportance object
        """
        return FeatureImportance(self.get_feature_importance_dict())

    # the deprecation message must name the current (snake_case) method names
    @deprecated("Use get_feature_importance_dict or the high-level interface get_feature_importance instead.")
    def get_feature_importances(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        """
        Deprecated alias of get_feature_importance_dict.
        """
        return self.get_feature_importance_dict()

89 

90 

def plot_feature_importance(feature_importance_dict: Dict[str, float], subtitle: Optional[str] = None, sort=True) -> plt.Figure:
    """
    Creates a horizontal bar plot of the given feature importance values.

    :param feature_importance_dict: a dictionary mapping feature names to importance values
    :param subtitle: an optional subtitle, which is added as a second line of the plot title
    :param sort: whether to order the features by descending importance value; if False, the dictionary's
        iteration order is used
    :return: the figure
    """
    if sort:
        feature_importance_dict = dict(sorted(feature_importance_dict.items(), key=lambda item: item[1], reverse=True))
    num_features = len(feature_importance_dict)
    default_width, default_height = MATPLOTLIB_DEFAULT_FIGURE_SIZE
    # grow the figure height linearly once there are more than 20 features (presumably to keep the bars legible)
    height = max(default_height, default_height * num_features / 20)
    fig, ax = plt.subplots(figsize=(default_width, height))
    sns.barplot(x=list(feature_importance_dict.values()), y=list(feature_importance_dict.keys()), ax=ax)
    title = "Feature Importance"
    if subtitle is not None:
        title += "\n" + subtitle
    # operate on the figure/axes created above rather than on pyplot's implicit "current" figure/axes
    ax.set_title(title)
    fig.tight_layout()
    return fig

105 

106 

class AggregatedFeatureImportance:
    """
    Aggregates feature importance values (e.g. from models implementing FeatureImportanceProvider, such as sklearn's RandomForest
    models and compatible models from lightgbm, etc.)
    """
    def __init__(self, *items: Union["FeatureImportanceProvider", Dict[str, float], Dict[str, Dict[str, float]]],
            feature_agg_reg_ex: Sequence[str] = (), agg_fn=np.mean):
        r"""
        :param items: (optional) initial list of feature importance providers or dictionaries to aggregate; further
            values can be added via method add
        :param feature_agg_reg_ex: a sequence of regular expressions describing which feature names to sum as one. Each regex must
            contain exactly one group. If a regex matches a feature name (via `re.match`, i.e. anchored at the start; the first
            matching regex is used), the feature importance will be summed under the key of the matched group instead of the full
            feature name. For example, the regex r"(\w+)_\d+$" will cause "foo_1" and "foo_2" to be summed under "foo" and similarly
            "bar_1" and "bar_2" to be summed under "bar". The summation is applied within each added dictionary, i.e. before the
            values of the individual dictionaries are aggregated via agg_fn.
        :param agg_fn: the function which, for each (predicted variable and) feature, reduces the list of values collected from
            all added dictionaries to a single value
        """
        self._agg_dict = None
        self._is_nested = None
        self._num_dicts_added = 0
        self._feature_agg_reg_ex = [re.compile(p) for p in feature_agg_reg_ex]
        self._agg_fn = agg_fn
        for item in items:
            self.add(item)

    @staticmethod
    def _is_dict(x) -> bool:
        # duck-typed check for dictionary-like objects
        return hasattr(x, "get")

    def add(self, feature_importance: Union["FeatureImportanceProvider", Dict[str, float], Dict[str, Dict[str, float]]]):
        """
        Adds the feature importance values from the given provider or dictionary

        :param feature_importance: a feature importance provider or a dictionary as obtained via a model's
            get_feature_importance_dict method (either flat or nested by predicted variable name)
        """
        if isinstance(feature_importance, FeatureImportanceProvider):
            feature_importance = feature_importance.get_feature_importance_dict()
        if not feature_importance:
            # nothing to aggregate; returning early also avoids StopIteration when inferring whether values are nested
            self._num_dicts_added += 1
            return
        if self._is_nested is None:
            # the first non-empty dictionary determines whether all added dictionaries are nested
            self._is_nested = self._is_dict(next(iter(feature_importance.values())))
        if self._is_nested:
            if self._agg_dict is None:
                self._agg_dict = collections.defaultdict(lambda: collections.defaultdict(list))
            for target_name, d in feature_importance.items():
                for agg_name, value in self._sum_by_agg_feature_name(d).items():
                    self._agg_dict[target_name][agg_name].append(value)
        else:
            if self._agg_dict is None:
                self._agg_dict = collections.defaultdict(list)
            for agg_name, value in self._sum_by_agg_feature_name(feature_importance).items():
                self._agg_dict[agg_name].append(value)
        self._num_dicts_added += 1

    def _agg_feature_name(self, feature_name: str) -> str:
        # maps a feature name to the first group of the first matching regex, or returns it unchanged if none matches
        for regex in self._feature_agg_reg_ex:
            m = regex.match(feature_name)
            if m is not None:
                return m.group(1)
        return feature_name

    def _sum_by_agg_feature_name(self, d: Dict[str, float]) -> Dict[str, float]:
        # maps the keys of a single (flat) dictionary to aggregated feature names, summing the values of features
        # which map to the same aggregated name
        result: Dict[str, float] = {}
        for feature_name, value in d.items():
            agg_name = self._agg_feature_name(feature_name)
            result[agg_name] = result[agg_name] + value if agg_name in result else value
        return result

    def get_aggregated_feature_importance_dict(self) -> Union[Dict[str, float], Dict[str, Dict[str, float]]]:
        """
        :return: the aggregated feature importance values (nested by predicted variable name if the added values were
            nested); an empty dictionary if no (non-empty) values have been added yet
        """
        if self._agg_dict is None:
            return {}

        def aggregate(d: dict) -> dict:
            return {k: self._agg_fn(values) for k, values in d.items()}

        if self._is_nested:
            return {k: aggregate(d) for k, d in self._agg_dict.items()}
        else:
            return aggregate(self._agg_dict)

    def get_aggregated_feature_importance(self) -> "FeatureImportance":
        """
        :return: the aggregated feature importance values wrapped in a FeatureImportance object
        """
        return FeatureImportance(self.get_aggregated_feature_importance_dict())

176 

177 

def compute_permutation_feature_importance_dict(model, io_data: InputOutputData, scoring, num_repeats: int, random_state,
        exclude_input_preprocessors=False, num_jobs=None) -> Dict[str, float]:
    """
    Computes permutation feature importance values using scikit-learn's permutation_importance.

    :param model: the model, which is passed to sklearn's permutation_importance as the estimator
    :param io_data: the data (inputs and outputs) on which to compute the importance values
    :param scoring: the scoring method (passed on to sklearn), e.g. "r2" for regression or "accuracy" for classification
    :param num_repeats: the number of times each feature is permuted (sklearn's n_repeats)
    :param random_state: the random state/seed controlling the permutations
    :param exclude_input_preprocessors: if True, the model's input preprocessors are applied to the inputs up front and
        the importance values are computed for the resulting transformed inputs (using a copy of the model from which
        the input preprocessors were removed) rather than for the original inputs
    :param num_jobs: the number of parallel jobs (sklearn's n_jobs)
    :return: a dictionary mapping each input column name to its importance value (mean over the repeats)
    :raises RuntimeError: if the number of importance values does not match the number of input columns
    """
    from sklearn.inspection import permutation_importance
    if exclude_input_preprocessors:
        inputs = model.compute_model_inputs(io_data.inputs)
        # NOTE(review): copy.copy is shallow; this assumes remove_input_preprocessors does not mutate state shared
        # with the original model -- confirm against VectorModel
        model = copy.copy(model)
        model.remove_input_preprocessors()
    else:
        inputs = io_data.inputs
    feature_names = list(inputs.columns)
    pi = permutation_importance(model, inputs, io_data.outputs, n_repeats=num_repeats, random_state=random_state, scoring=scoring,
        n_jobs=num_jobs)
    importance_values = pi.importances_mean
    # explicit check rather than assert, such that it is not stripped when running with -O
    if len(importance_values) != len(feature_names):
        raise RuntimeError(f"Number of importance values ({len(importance_values)}) does not match the number of "
                           f"input columns ({len(feature_names)})")
    return dict(zip(feature_names, importance_values))

194 

195 

class AggregatedPermutationFeatureImportance(ToStringMixin):
    """
    Computes permutation feature importance values for one or more models (e.g. the models of a cross-validation)
    and collects them in an AggregatedFeatureImportance instance.
    """
    def __init__(self, aggregated_feature_importance: AggregatedFeatureImportance, scoring, num_repeats=5, random_seed=42,
            exclude_model_input_preprocessors=False, num_jobs: Optional[int] = None):
        """
        :param aggregated_feature_importance: the object in which to aggregate the feature importance (to which no feature
            importance values should have been added yet)
        :param scoring: the scoring method; see https://scikit-learn.org/stable/modules/model_evaluation.html; e.g. "r2" for
            regression or "accuracy" for classification
        :param num_repeats: the number of data permutations to apply for each model
        :param random_seed: the random seed for shuffling the data
        :param exclude_model_input_preprocessors: whether to exclude model input preprocessors, such that the feature importance
            will be reported on the transformed inputs that are actually fed to the model rather than on the original inputs.
            Enabling this can, for example, help save time in cases where the input preprocessors discard many of the raw input
            columns, but it may not be a good idea if the preprocessors generate multiple columns from the original input columns.
        :param num_jobs: the number of jobs to run in parallel. Each separate model-data permutation feature importance
            computation is parallelised over the columns. `None` means 1 unless in a :obj:`joblib.parallel_backend` context;
            `-1` means using all processors.
        """
        self._agg = aggregated_feature_importance
        self.scoring = scoring
        self.numRepeats = num_repeats
        self.randomSeed = random_seed
        self.excludeModelInputPreprocessors = exclude_model_input_preprocessors
        self.numJobs = num_jobs

    def add(self, model: VectorModel, io_data: InputOutputData):
        """
        Computes the permutation feature importance of the given model on the given data and adds it to the aggregation.

        :param model: the model to evaluate
        :param io_data: the data on which to evaluate the model
        """
        importance_by_feature = compute_permutation_feature_importance_dict(
            model,
            io_data,
            self.scoring,
            num_repeats=self.numRepeats,
            random_state=self.randomSeed,
            exclude_input_preprocessors=self.excludeModelInputPreprocessors,
            num_jobs=self.numJobs,
        )
        self._agg.add(importance_by_feature)

    def add_cross_validation_data(self, cross_val_data: VectorModelCrossValidationData):
        """
        Adds the permutation feature importance of each trained model of a cross-validation, each computed on its
        corresponding evaluation data.

        :param cross_val_data: the cross-validation data, which must contain the trained models
        :raises ValueError: if the cross-validation data contains no trained models
        """
        models = cross_val_data.trained_models
        if models is None:
            raise ValueError("No models in cross-validation data; enable model collection during cross-validation")
        fold_pairs = zip(models, cross_val_data.eval_data_list)
        for fold_no, (fold_model, fold_eval_data) in enumerate(fold_pairs, start=1):
            log.info(f"Computing permutation feature importance for model #{fold_no}/{len(models)}")
            self.add(fold_model, fold_eval_data.io_data)

    def get_feature_importance(self) -> FeatureImportance:
        """
        :return: the feature importance values aggregated by the underlying AggregatedFeatureImportance instance
        """
        return self._agg.get_aggregated_feature_importance()