Coverage for src/sensai/evaluation/eval_stats/eval_stats_base.py: 52%


from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Union, Dict, Sequence, Optional, Tuple, Callable

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ...util.plot import ScatterPlot, HistogramPlot, Plot, HeatMapPlot
from ...util.string import ToStringMixin, dict_string
from ...vector_model import VectorModel

# Note: In the 2020.2 version of PyCharm, passing a string as `bound` is highlighted as an error.
# It does not cause runtime errors, and the static type checker ignores the bound anyway, so it does not matter for now.
# However, this might cause problems with type checking in the future. Therefore, the definition of TEvalStats was moved
# below the definition of EvalStats. Unfortunately, the generic dependency between EvalStats and Metric
# does not allow both TMetric and TEvalStats to be defined properly; for now, we have to leave the bound as a string
# and hope for the best.
TMetric = TypeVar("TMetric", bound="Metric")
TVectorModel = TypeVar("TVectorModel", bound=VectorModel)

PredictionArray = Union[np.ndarray, pd.Series, pd.DataFrame, list]


class EvalStats(Generic[TMetric], ToStringMixin):
    def __init__(self, metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        if len(metrics) == 0:
            raise ValueError("No metrics provided")
        self.metrics = metrics
        # Implementations of EvalStats will typically provide default metrics, therefore we include
        # the possibility of passing additional metrics here
        if additional_metrics is not None:
            self.metrics = self.metrics + additional_metrics
        self.name = None

    def set_name(self, name: str):
        self.name = name

    def add_metric(self, metric: TMetric):
        self.metrics.append(metric)

    def compute_metric_value(self, metric: TMetric) -> float:
        return metric.compute_value_for_eval_stats(self)

    def metrics_dict(self) -> Dict[str, float]:
        """
        Computes all metrics

        :return: a dictionary mapping metric names to values
        """
        d = {}
        for metric in self.metrics:
            d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_all(self) -> Dict[str, float]:
        """Alias for metrics_dict; may be deprecated in the future"""
        return self.metrics_dict()

    def _tostring_object_info(self) -> str:
        return dict_string(self.metrics_dict())


TEvalStats = TypeVar("TEvalStats", bound=EvalStats)


class Metric(Generic[TEvalStats], ABC):
    name: str

    def __init__(self, name: str = None, bounds: Optional[Tuple[float, float]] = None):
        """
        :param name: the name of the metric; if None, use the class' name attribute
        :param bounds: the minimum and maximum values the metric can take on (or None if the bounds are not specified)
        """
        # this raises an attribute error if a subclass specifies a name neither as a static attribute nor as a parameter
        self.name = name if name is not None else self.__class__.name
        self.bounds = bounds

    @abstractmethod
    def compute_value_for_eval_stats(self, eval_stats: TEvalStats) -> float:
        pass

    def get_paired_metrics(self) -> List[TMetric]:
        """
        Gets a list of metrics that should be considered together with this metric (e.g. for paired visualisations/plots).
        The direction of the pairing should be such that if this metric is "x", the other is "y" for x-y type visualisations.

        :return: a list of metrics
        """
        return []

    def has_finite_bounds(self) -> bool:
        return self.bounds is not None and not any((np.isinf(x) for x in self.bounds))
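
# Illustrative sketch (not part of the library): one way a concrete Metric and EvalStats
# pair might be defined. `MeanErrorMetric` and `SimpleEvalStats` are hypothetical names
# used only for this example.
#
#     class MeanErrorMetric(Metric["SimpleEvalStats"]):
#         name = "MeanError"
#
#         def compute_value_for_eval_stats(self, eval_stats: "SimpleEvalStats") -> float:
#             return float(np.mean(eval_stats.errors))
#
#     class SimpleEvalStats(EvalStats[MeanErrorMetric]):
#         def __init__(self, errors: List[float]):
#             self.errors = errors
#             super().__init__(metrics=[MeanErrorMetric()])
#
#     SimpleEvalStats([1.0, 2.0, 3.0]).metrics_dict()  # -> {'MeanError': 2.0}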


class EvalStatsCollection(Generic[TEvalStats, TMetric], ABC):
    def __init__(self, eval_stats_list: List[TEvalStats]):
        self.statsList = eval_stats_list
        metric_names_set = None
        metrics_list = []
        for es in eval_stats_list:
            metrics = es.metrics_dict()
            current_metric_names_set = set(metrics.keys())
            if metric_names_set is None:
                metric_names_set = current_metric_names_set
            else:
                if metric_names_set != current_metric_names_set:
                    raise Exception(f"Inconsistent set of metrics in evaluation stats collection: "
                                    f"Got {metric_names_set} for one instance, {current_metric_names_set} for another")
            metrics_list.append(metrics)
        metric_names = sorted(metrics_list[0].keys())
        self._valuesByMetricName = {metric: [d[metric] for d in metrics_list] for metric in metric_names}
        self._metrics: List[TMetric] = eval_stats_list[0].metrics

    def get_values(self, metric_name: str):
        return self._valuesByMetricName[metric_name]

    def get_metric_names(self) -> List[str]:
        return list(self._valuesByMetricName.keys())

    def get_metrics(self) -> List[TMetric]:
        return self._metrics

    def get_metric_by_name(self, name: str) -> Optional[TMetric]:
        for m in self._metrics:
            if m.name == name:
                return m
        return None

    def has_metric(self, metric: Union[Metric, str]) -> bool:
        if type(metric) != str:
            metric = metric.name
        return metric in self._valuesByMetricName

    def agg_metrics_dict(self, agg_fns=(np.mean, np.std)) -> Dict[str, float]:
        agg = {}
        for metric, values in self._valuesByMetricName.items():
            for agg_fn in agg_fns:
                agg[f"{agg_fn.__name__}[{metric}]"] = float(agg_fn(values))
        return agg

    def mean_metrics_dict(self) -> Dict[str, float]:
        metrics = {metric: np.mean(values) for (metric, values) in self._valuesByMetricName.items()}
        return metrics

    def plot_distribution(self, metric_name: str, subtitle: Optional[str] = None, bins=None, kde=False, cdf=False,
            cdf_complementary=False, stat="proportion", **kwargs) -> plt.Figure:
        """
        Plots the distribution of a metric as a histogram

        :param metric_name: name of the metric for which to plot the distribution (histogram) across evaluations
        :param subtitle: the subtitle to add, if any
        :param bins: the histogram bins (number of bins or boundaries); the metric's bounds will be used to define the x limits.
            If None, use 'auto' bins
        :param kde: whether to add a kernel density estimator plot
        :param cdf: whether to add the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if ``cdf`` is True, the complementary cdf instead of the regular cdf
        :param stat: the statistic to compute for each bin ('percent', 'probability'='proportion', 'count', 'frequency' or 'density'),
            i.e. the y-axis value
        :param kwargs: additional parameters to pass to seaborn.histplot (see https://seaborn.pydata.org/generated/seaborn.histplot.html)
        :return: the plot
        """
        # define bins based on metric bounds where available
        x_tick = None
        if bins is None or type(bins) == int:
            metric = self.get_metric_by_name(metric_name)
            if metric.bounds == (0, 1):
                x_tick = 0.1
                if bins is None:
                    num_bins = 10 if cdf else 20
                else:
                    num_bins = bins
                bins = np.linspace(0, 1, num_bins + 1)
            else:
                bins = "auto"

        values = self._valuesByMetricName[metric_name]
        title = metric_name
        if subtitle is not None:
            title += "\n" + subtitle
        plot = HistogramPlot(values, bins=bins, stat=stat, kde=kde, cdf=cdf, cdf_complementary=cdf_complementary, **kwargs).title(title)
        if x_tick is not None:
            plot.xtick_major(x_tick)
        return plot.fig

    def _plot_xy(self, metric_name_x, metric_name_y, plot_factory: Callable[[Sequence, Sequence], Plot], adjust_bounds: bool) -> plt.Figure:
        def axlim(bounds):
            min_value, max_value = bounds
            diff = max_value - min_value
            return (min_value - 0.05 * diff, max_value + 0.05 * diff)

        x = self._valuesByMetricName[metric_name_x]
        y = self._valuesByMetricName[metric_name_y]
        plot = plot_factory(x, y)
        plot.xlabel(metric_name_x)
        plot.ylabel(metric_name_y)
        mx = self.get_metric_by_name(metric_name_x)
        if adjust_bounds and mx.has_finite_bounds():
            plot.xlim(*axlim(mx.bounds))
        my = self.get_metric_by_name(metric_name_y)
        if adjust_bounds and my.has_finite_bounds():
            plot.ylim(*axlim(my.bounds))
        return plot.fig

    def plot_scatter(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, ScatterPlot, adjust_bounds=True)

    def plot_heat_map(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, HeatMapPlot, adjust_bounds=False)

    def to_data_frame(self) -> pd.DataFrame:
        """
        :return: a DataFrame with the evaluation metrics from all contained EvalStats objects;
            the EvalStats objects' name field is used as the index if it is set
        """
        data = dict(self._valuesByMetricName)
        index = [stats.name for stats in self.statsList]
        if len([n for n in index if n is not None]) == 0:
            index = None
        return pd.DataFrame(data, index=index)

    def get_global_stats(self) -> TEvalStats:
        """
        Alias for `get_combined_eval_stats`
        """
        return self.get_combined_eval_stats()

    @abstractmethod
    def get_combined_eval_stats(self) -> TEvalStats:
        """
        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        pass

    def __str__(self):
        # note: the keys of agg_metrics_dict() include the aggregation function name, e.g. "mean[<metric>]"
        agg_metrics = self.agg_metrics_dict()
        return f"{self.__class__.__name__}[" + ", ".join([f"{key}={value:.4f}" for key, value in agg_metrics.items()]) + "]"
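
# Usage sketch (illustrative only; `MyEvalStatsCollection` is a hypothetical concrete
# subclass of EvalStatsCollection, e.g. one built from per-fold EvalStats objects):
#
#     collection = MyEvalStatsCollection(eval_stats_per_fold)
#     collection.agg_metrics_dict()   # e.g. {'mean[MAE]': 0.42, 'std[MAE]': 0.05, ...}
#     collection.to_data_frame()      # one row per EvalStats object, one column per metric
#     fig = collection.plot_distribution("MAE")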


class PredictionEvalStats(EvalStats[TMetric], ABC):
    """
    Collects data for the evaluation of predicted values (including multi-dimensional predictions)
    and computes corresponding metrics
    """
    def __init__(self, y_predicted: Optional[PredictionArray], y_true: Optional[PredictionArray],
            metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        """
        :param y_predicted: sequence of predicted values, or, in case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of same shape as y_predicted
        :param metrics: list of metrics to be computed on the provided data
        :param additional_metrics: the metrics to additionally compute. This should only be provided if metrics is None
        """
        self.y_true = []
        self.y_predicted = []
        self.y_true_multidim = None
        self.y_predicted_multidim = None
        if y_predicted is not None:
            self.add_all(y_predicted, y_true)
        super().__init__(metrics, additional_metrics=additional_metrics)

    def add(self, y_predicted, y_true) -> None:
        """
        Adds a single pair of values to the evaluation

        :param y_predicted: the value predicted by the model
        :param y_true: the true value
        """
        self.y_true.append(y_true)
        self.y_predicted.append(y_predicted)

    def add_all(self, y_predicted: PredictionArray, y_true: PredictionArray) -> None:
        """
        :param y_predicted: sequence of predicted values, or, in case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of same shape as y_predicted
        """
        def is_sequence(x):
            return isinstance(x, pd.Series) or isinstance(x, list) or isinstance(x, np.ndarray)

        if is_sequence(y_predicted) and is_sequence(y_true):
            a, b = len(y_predicted), len(y_true)
            if a != b:
                raise Exception(f"Lengths differ (predicted {a}, truth {b})")
            if a > 0:
                first_item = y_predicted.iloc[0] if isinstance(y_predicted, pd.Series) else y_predicted[0]
                is_nested_sequence = is_sequence(first_item)
                if is_nested_sequence:
                    for y_true_i, y_predicted_i in zip(y_true, y_predicted):
                        self.add_all(y_predicted=y_predicted_i, y_true=y_true_i)
                else:
                    self.y_true.extend(y_true)
                    self.y_predicted.extend(y_predicted)
        elif isinstance(y_predicted, pd.DataFrame) and isinstance(y_true, pd.DataFrame):
            # keep track of multidimensional data (to be used later in getEvalStatsCollection)
            y_predicted_multidim = y_predicted.values
            y_true_multidim = y_true.values
            dim = y_predicted_multidim.shape[1]
            if dim != y_true_multidim.shape[1]:
                raise Exception("Dimension mismatch")
            if self.y_true_multidim is None:
                self.y_predicted_multidim = [[] for _ in range(dim)]
                self.y_true_multidim = [[] for _ in range(dim)]
            if len(self.y_predicted_multidim) != dim:
                raise Exception("Dimension mismatch")
            for i in range(dim):
                self.y_predicted_multidim[i].extend(y_predicted_multidim[:, i])
                self.y_true_multidim[i].extend(y_true_multidim[:, i])
            # convert to flat data for this stats object
            y_predicted = y_predicted_multidim.reshape(-1)
            y_true = y_true_multidim.reshape(-1)
            self.y_true.extend(y_true)
            self.y_predicted.extend(y_predicted)
        else:
            raise Exception(f"Unhandled data types: {type(y_predicted)}, {type(y_true)}")

    def _tostring_object_info(self) -> str:
        return f"{super()._tostring_object_info()}, N={len(self.y_predicted)}"
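
# Behavioural sketch (illustrative only; `MyPredictionEvalStats` and `some_metric` are
# hypothetical placeholders for a concrete subclass and a metric instance): flat sequences
# are accumulated as-is, while DataFrames with one column per output dimension are
# additionally tracked per dimension and flattened for the overall statistics.
#
#     stats = MyPredictionEvalStats(y_predicted=None, y_true=None, metrics=[some_metric])
#     stats.add_all([1.0, 2.0], [1.5, 2.5])                  # flat case: N=2
#     stats.add_all(pd.DataFrame({"a": [1.0], "b": [2.0]}),
#                   pd.DataFrame({"a": [1.5], "b": [2.5]}))  # 2D case: flattened, N=4 in total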


def mean_stats(eval_stats_list: Sequence[EvalStats]) -> Dict[str, float]:
    """
    For a list of EvalStats objects, computes the mean values of all metrics in a dictionary.
    Assumes that all provided EvalStats objects have the same metrics.
    """
    dicts = [s.metrics_dict() for s in eval_stats_list]
    metrics = dicts[0].keys()
    return {m: np.mean([d[m] for d in dicts]) for m in metrics}
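
# Worked example (hypothetical values): for two EvalStats objects whose metrics_dict()
# returns {'MAE': 1.0} and {'MAE': 3.0} respectively, mean_stats yields {'MAE': 2.0}.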


class EvalStatsPlot(Generic[TEvalStats], ABC):
    @abstractmethod
    def create_figure(self, eval_stats: TEvalStats, subtitle: str) -> Optional[plt.Figure]:
        """
        :param eval_stats: the evaluation stats from which to generate the plot
        :param subtitle: the plot's subtitle
        :return: the figure, or None if this plot is not applicable/cannot be created
        """
        pass