Coverage for src/sensai/evaluation/eval_stats/eval_stats_base.py: 50%

212 statements  

coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Union, Dict, Sequence, Optional, Tuple, Callable

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ...util.pickle import setstate
from ...util.plot import ScatterPlot, HistogramPlot, Plot, HeatMapPlot
from ...util.string import ToStringMixin, dict_string
from ...vector_model import VectorModel

# Note: in the 2020.2 version of PyCharm, passing strings to `bound` is highlighted as an error.
# This does not cause runtime errors, and the static type checker ignores the bound anyway, so it does not matter for now.
# However, it might cause problems with type checking in the future. Therefore, I moved the definition of TEvalStats
# below the definition of EvalStats. Unfortunately, the generic dependency between EvalStats and Metric
# does not allow both TMetric and TEvalStats to be defined properly. For now, we have to leave the bound as a string
# and hope for the best.
TMetric = TypeVar("TMetric", bound="Metric")
TVectorModel = TypeVar("TVectorModel", bound=VectorModel)

Array = Union[np.ndarray, pd.Series, list]
PredictionArray = Union[np.ndarray, pd.Series, pd.DataFrame, list]


class EvalStats(Generic[TMetric], ToStringMixin):
    def __init__(self, metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        if len(metrics) == 0:
            raise ValueError("No metrics provided")
        self.metrics = metrics
        # Implementations of EvalStats will typically provide default metrics, therefore we include
        # the possibility of passing additional metrics here
        if additional_metrics is not None:
            self.metrics = self.metrics + additional_metrics
        self.name = None

    def set_name(self, name: str):
        self.name = name

    def add_metric(self, metric: TMetric):
        self.metrics.append(metric)

    def compute_metric_value(self, metric: TMetric) -> float:
        return metric.compute_value_for_eval_stats(self)

    def metrics_dict(self) -> Dict[str, float]:
        """
        Computes all metrics

        :return: a dictionary mapping metric names to values
        """
        d = {}
        for metric in self.metrics:
            d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_all(self) -> Dict[str, float]:
        """Alias for metrics_dict; may be deprecated in the future"""
        return self.metrics_dict()

    def _tostring_object_info(self) -> str:
        return dict_string(self.metrics_dict())


TEvalStats = TypeVar("TEvalStats", bound=EvalStats)


class Metric(Generic[TEvalStats], ABC):
    name: str

    def __init__(self, name: str = None, bounds: Optional[Tuple[float, float]] = None):
        """
        :param name: the name of the metric; if None, use the class attribute ``name``
        :param bounds: the minimum and maximum values the metric can take on (or None if the bounds are not specified)
        """
        # this raises an AttributeError if a subclass specifies the name neither as a class attribute nor as a parameter
        self.name = name if name is not None else self.__class__.name
        self.bounds = bounds

    @abstractmethod
    def compute_value_for_eval_stats(self, eval_stats: TEvalStats) -> float:
        pass

    def get_paired_metrics(self) -> List[TMetric]:
        """
        Gets a list of metrics that should be considered together with this metric (e.g. for paired visualisations/plots).
        The direction of the pairing should be such that if this metric is "x", the other is "y" for x-y type visualisations.

        :return: a list of metrics
        """
        return []

    def has_finite_bounds(self) -> bool:
        return self.bounds is not None and not any(np.isinf(x) for x in self.bounds)
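

# --- Illustrative sketch (not part of the original module): a minimal custom metric. ---
# It assumes an EvalStats implementation that exposes y_true/y_predicted lists, as
# PredictionEvalStats defined further below does; the class and metric name are hypothetical.
class MetricMeanAbsoluteError(Metric["PredictionEvalStats"]):
    name = "MAE"

    def __init__(self):
        # MAE is non-negative but unbounded above, so has_finite_bounds() returns False
        super().__init__(bounds=(0, np.inf))

    def compute_value_for_eval_stats(self, eval_stats: "PredictionEvalStats") -> float:
        y_true = np.asarray(eval_stats.y_true)
        y_predicted = np.asarray(eval_stats.y_predicted)
        return float(np.mean(np.abs(y_true - y_predicted)))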


class EvalStatsCollection(Generic[TEvalStats, TMetric], ABC):
    def __init__(self, eval_stats_list: List[TEvalStats]):
        self.statsList = eval_stats_list
        metric_names_set = None
        metrics_list = []
        for es in eval_stats_list:
            metrics = es.metrics_dict()
            current_metric_names_set = set(metrics.keys())
            if metric_names_set is None:
                metric_names_set = current_metric_names_set
            else:
                if metric_names_set != current_metric_names_set:
                    raise Exception(f"Inconsistent set of metrics in evaluation stats collection: "
                                    f"Got {metric_names_set} for one instance, {current_metric_names_set} for another")
            metrics_list.append(metrics)
        metric_names = sorted(metrics_list[0].keys())
        self._valuesByMetricName = {metric: [d[metric] for d in metrics_list] for metric in metric_names}
        self._metrics: List[TMetric] = eval_stats_list[0].metrics

    def get_values(self, metric_name: str):
        return self._valuesByMetricName[metric_name]

    def get_metric_names(self) -> List[str]:
        return list(self._valuesByMetricName.keys())

    def get_metrics(self) -> List[TMetric]:
        return self._metrics

    def get_metric_by_name(self, name: str) -> Optional[TMetric]:
        for m in self._metrics:
            if m.name == name:
                return m
        return None

    def has_metric(self, metric: Union[Metric, str]) -> bool:
        if type(metric) != str:
            metric = metric.name
        return metric in self._valuesByMetricName

    def agg_metrics_dict(self, agg_fns=(np.mean, np.std)) -> Dict[str, float]:
        agg = {}
        for metric, values in self._valuesByMetricName.items():
            for agg_fn in agg_fns:
                agg[f"{agg_fn.__name__}[{metric}]"] = float(agg_fn(values))
        return agg

    def mean_metrics_dict(self) -> Dict[str, float]:
        metrics = {metric: np.mean(values) for (metric, values) in self._valuesByMetricName.items()}
        return metrics

    def plot_distribution(self, metric_name: str, subtitle: Optional[str] = None, bins=None, kde=False, cdf=False,
            cdf_complementary=False, stat="proportion", **kwargs) -> plt.Figure:
        """
        Plots the distribution of a metric as a histogram

        :param metric_name: name of the metric for which to plot the distribution (histogram) across evaluations
        :param subtitle: the subtitle to add, if any
        :param bins: the histogram bins (number of bins or bin boundaries); the metric's bounds, where available, are used to
            define the x-axis limits. If None, use 'auto' bins
        :param kde: whether to add a kernel density estimate plot
        :param cdf: whether to add the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if ``cdf`` is True, the complementary cdf instead of the regular cdf
        :param stat: the statistic to compute for each bin ('percent', 'probability'='proportion', 'count', 'frequency' or
            'density'), which determines the y-axis values
        :param kwargs: additional parameters to pass to seaborn.histplot (see https://seaborn.pydata.org/generated/seaborn.histplot.html)
        :return: the plot
        """
        # define bins based on metric bounds where available
        x_tick = None
        if bins is None or type(bins) == int:
            metric = self.get_metric_by_name(metric_name)
            if metric.bounds == (0, 1):
                x_tick = 0.1
                if bins is None:
                    num_bins = 10 if cdf else 20
                else:
                    num_bins = bins
                bins = np.linspace(0, 1, num_bins + 1)
            else:
                bins = "auto"

        values = self._valuesByMetricName[metric_name]
        title = metric_name
        if subtitle is not None:
            title += "\n" + subtitle
        plot = HistogramPlot(values, bins=bins, stat=stat, kde=kde, cdf=cdf, cdf_complementary=cdf_complementary, **kwargs).title(title)
        if x_tick is not None:
            plot.xtick_major(x_tick)
        return plot.fig

    def _plot_xy(self, metric_name_x, metric_name_y, plot_factory: Callable[[Sequence, Sequence], Plot], adjust_bounds: bool) -> plt.Figure:
        def axlim(bounds):
            min_value, max_value = bounds
            diff = max_value - min_value
            return (min_value - 0.05 * diff, max_value + 0.05 * diff)

        x = self._valuesByMetricName[metric_name_x]
        y = self._valuesByMetricName[metric_name_y]
        plot = plot_factory(x, y)
        plot.xlabel(metric_name_x)
        plot.ylabel(metric_name_y)
        mx = self.get_metric_by_name(metric_name_x)
        if adjust_bounds and mx.has_finite_bounds():
            plot.xlim(*axlim(mx.bounds))
        my = self.get_metric_by_name(metric_name_y)
        if adjust_bounds and my.has_finite_bounds():
            plot.ylim(*axlim(my.bounds))
        return plot.fig

    def plot_scatter(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, ScatterPlot, adjust_bounds=True)

    def plot_heat_map(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, HeatMapPlot, adjust_bounds=False)

    def to_data_frame(self) -> pd.DataFrame:
        """
        :return: a DataFrame with the evaluation metrics from all contained EvalStats objects;
            the EvalStats' name field is used as the index if it is set
        """
        data = dict(self._valuesByMetricName)
        index = [stats.name for stats in self.statsList]
        if len([n for n in index if n is not None]) == 0:
            index = None
        return pd.DataFrame(data, index=index)

    def get_global_stats(self) -> TEvalStats:
        """
        Alias for `get_combined_eval_stats`
        """
        return self.get_combined_eval_stats()

    @abstractmethod
    def get_combined_eval_stats(self) -> TEvalStats:
        """
        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        pass

    def __str__(self):
        agg_metrics = self.agg_metrics_dict()
        return f"{self.__class__.__name__}[" + \
            ", ".join([f"{key}={agg_metrics[key]:.4f}" for key in agg_metrics]) + "]"
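

# Illustrative usage sketch (not part of the original module), assuming a concrete
# EvalStatsCollection subclass holding one EvalStats object per cross-validation fold,
# each of which includes the hypothetical "MAE" metric defined above:
#
#     collection = MyEvalStatsCollection(eval_stats_per_fold)  # hypothetical concrete subclass
#     collection.agg_metrics_dict()        # {"mean[MAE]": ..., "std[MAE]": ...}
#     collection.mean_metrics_dict()       # {"MAE": ...}
#     fig = collection.plot_distribution("MAE", cdf=True)
#     df = collection.to_data_frame()      # one row per contained EvalStats object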


class PredictionEvalStats(EvalStats[TMetric], ABC):
    """
    Collects data for the evaluation of predicted values (including multi-dimensional predictions)
    and computes corresponding metrics
    """
    def __init__(self, y_predicted: Optional[PredictionArray], y_true: Optional[PredictionArray],
                 metrics: List[TMetric], additional_metrics: List[TMetric] = None,
                 weights: Optional[Array] = None):
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        :param metrics: list of metrics to be computed on the provided data
        :param additional_metrics: metrics to compute in addition to ``metrics`` (typically used when ``metrics`` contains an
            implementation's default metrics)
        :param weights: weights for each data point contained in `y_predicted` and `y_true`
        """
        self.y_true = []
        self.y_predicted = []
        self.weights: Optional[List[float]] = None
        self.y_true_multidim = None
        self.y_predicted_multidim = None
        if y_predicted is not None:
            self.add_all(y_predicted=y_predicted, y_true=y_true, weights=weights)
        super().__init__(metrics, additional_metrics=additional_metrics)

    def __setstate__(self, state):
        return setstate(PredictionEvalStats, self, state, new_optional_properties=["weights"])

    def add(self, y_predicted, y_true, weight: Optional[float] = None) -> None:
        """
        Adds a single pair of values to the evaluation

        :param y_predicted: the value predicted by the model
        :param y_true: the true value
        :param weight: an optional weight for the data point
        """
        self.y_true.append(y_true)
        self.y_predicted.append(y_predicted)
        if weight is not None:
            if self.weights is None:
                self.weights = []
            self.weights.append(weight)

    def add_all(self, y_predicted: PredictionArray, y_true: PredictionArray, weights: Optional[Array] = None) -> None:
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        :param weights: optional weights of data points
        """
        def is_sequence(x):
            return isinstance(x, pd.Series) or isinstance(x, list) or isinstance(x, np.ndarray)

        if is_sequence(y_predicted) and is_sequence(y_true):
            a, b = len(y_predicted), len(y_true)
            if a != b:
                raise Exception(f"Lengths differ (predicted {a}, truth {b})")
            if a > 0:
                first_item = y_predicted.iloc[0] if isinstance(y_predicted, pd.Series) else y_predicted[0]
                is_nested_sequence = is_sequence(first_item)
                if is_nested_sequence:
                    for y_true_i, y_predicted_i in zip(y_true, y_predicted):
                        self.add_all(y_predicted=y_predicted_i, y_true=y_true_i)
                else:
                    self.y_true.extend(y_true)
                    self.y_predicted.extend(y_predicted)
        elif isinstance(y_predicted, pd.DataFrame) and isinstance(y_true, pd.DataFrame):
            # keep track of multidimensional data (to be used later in get_eval_stats_collection)
            y_predicted_multidim = y_predicted.values
            y_true_multidim = y_true.values
            dim = y_predicted_multidim.shape[1]
            if dim != y_true_multidim.shape[1]:
                raise Exception("Dimension mismatch")
            if self.y_true_multidim is None:
                self.y_predicted_multidim = [[] for _ in range(dim)]
                self.y_true_multidim = [[] for _ in range(dim)]
            if len(self.y_predicted_multidim) != dim:
                raise Exception("Dimension mismatch")
            for i in range(dim):
                self.y_predicted_multidim[i].extend(y_predicted_multidim[:, i])
                self.y_true_multidim[i].extend(y_true_multidim[:, i])
            # convert to flat data for this stats object
            y_predicted = y_predicted_multidim.reshape(-1)
            y_true = y_true_multidim.reshape(-1)
            self.y_true.extend(y_true)
            self.y_predicted.extend(y_predicted)
        else:
            raise Exception(f"Unhandled data types: {type(y_predicted)}, {type(y_true)}")

        if weights is not None:
            if self.weights is None:
                self.weights = []
            assert len(weights) == len(self.y_predicted) - len(self.weights), "Length of weights does not match"
            self.weights.extend(weights)

    def _tostring_object_info(self) -> str:
        return f"{super()._tostring_object_info()}, N={len(self.y_predicted)}"
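

# Illustrative usage sketch (not part of the original module): accumulating predictions in a
# minimal concrete PredictionEvalStats subclass and computing the hypothetical MAE metric
# defined above. The subclass name is an assumption for illustration only.
#
#     class SimplePredictionEvalStats(PredictionEvalStats[Metric]):
#         pass
#
#     stats = SimplePredictionEvalStats(y_predicted=[2.0, 4.0], y_true=[2.5, 3.5],
#         metrics=[MetricMeanAbsoluteError()])
#     stats.add(y_predicted=1.0, y_true=1.0)   # add a further data point
#     stats.metrics_dict()                     # {"MAE": ...}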


def mean_stats(eval_stats_list: Sequence[EvalStats]) -> Dict[str, float]:
    """
    For a list of EvalStats objects, computes the mean value of each metric and returns the results in a dictionary.
    Assumes that all provided EvalStats objects use the same metrics.
    """
    dicts = [s.metrics_dict() for s in eval_stats_list]
    metrics = dicts[0].keys()
    return {m: np.mean([d[m] for d in dicts]) for m in metrics}
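

# Illustrative usage sketch (not part of the original module), assuming per-fold
# PredictionEvalStats objects that share the same metrics:
#
#     mean_stats([stats_fold_1, stats_fold_2])   # {"MAE": ...}, averaged across folds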


class EvalStatsPlot(Generic[TEvalStats], ABC):
    @abstractmethod
    def create_figure(self, eval_stats: TEvalStats, subtitle: str) -> Optional[plt.Figure]:
        """
        :param eval_stats: the evaluation stats from which to generate the plot
        :param subtitle: the plot's subtitle
        :return: the figure or None if this plot is not applicable/cannot be created
        """
        pass

    def is_applicable(self, eval_stats: TEvalStats) -> bool:
        return True
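

# Illustrative sketch (not part of the original module): a minimal EvalStatsPlot
# implementation. It assumes an EvalStats implementation with y_true/y_predicted lists
# (such as PredictionEvalStats above) and that ScatterPlot supports the same
# xlabel/ylabel/title interface used elsewhere in this module; the class name is hypothetical.
class EvalStatsPlotActualVsPredicted(EvalStatsPlot[PredictionEvalStats]):
    def create_figure(self, eval_stats: PredictionEvalStats, subtitle: str) -> Optional[plt.Figure]:
        plot = ScatterPlot(eval_stats.y_true, eval_stats.y_predicted)
        plot.xlabel("actual")
        plot.ylabel("predicted")
        plot.title(f"Actual vs. predicted\n{subtitle}")
        return plot.fig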