# src/sensai/evaluation/eval_stats/eval_stats_regression.py

import logging
from abc import abstractmethod, ABC
from typing import List, Sequence, Optional

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

from . import BinaryClassificationMetric
from .eval_stats_base import PredictionEvalStats, Metric, EvalStatsCollection, PredictionArray, EvalStatsPlot, Array
from ...util import kwarg_if_not_none
from ...util.plot import HistogramPlot
from ...vector_model import VectorRegressionModel, InputOutputData

log = logging.getLogger(__name__)


class RegressionMetric(Metric["RegressionEvalStats"], ABC):
    def compute_value_for_eval_stats(self, eval_stats: "RegressionEvalStats"):
        weights = np.array(eval_stats.weights) if eval_stats.weights is not None else None
        return self.compute_value(np.array(eval_stats.y_true), np.array(eval_stats.y_predicted),
                                  model=eval_stats.model,
                                  io_data=eval_stats.ioData,
                                  **kwarg_if_not_none("weights", weights))

    @abstractmethod
    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray,
                      model: VectorRegressionModel = None,
                      io_data: InputOutputData = None,
                      weights: Optional[np.ndarray] = None):
        pass

    @classmethod
    def compute_errors(cls, y_true: np.ndarray, y_predicted: np.ndarray):
        return y_predicted - y_true

    @classmethod
    def compute_abs_errors(cls, y_true: np.ndarray, y_predicted: np.ndarray):
        return np.abs(cls.compute_errors(y_true, y_predicted))
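

# A custom metric can be added by subclassing RegressionMetric above; minimal sketch
# (the class name and formula are illustrative, not part of the library):
#
#     class RegressionMetricMaxAE(RegressionMetric):
#         name = "MaxAE"
#
#         def compute_value(self, y_true, y_predicted, model=None, io_data=None, weights=None):
#             return np.max(self.compute_abs_errors(y_true, y_predicted))
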

class RegressionMetricMAE(RegressionMetric):
    name = "MAE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return mean_absolute_error(y_true, y_predicted, sample_weight=weights)


class RegressionMetricMSE(RegressionMetric):
    name = "MSE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return mean_squared_error(y_true, y_predicted, sample_weight=weights)


class RegressionMetricRMSE(RegressionMetric):
    name = "RMSE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return np.sqrt(mean_squared_error(y_true, y_predicted, sample_weight=weights))


class RegressionMetricRRSE(RegressionMetric):
    name = "RRSE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        r2 = r2_score(y_true, y_predicted, sample_weight=weights)
        return np.sqrt(1 - r2)
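

# Note on RegressionMetricRRSE above: the root relative squared error is defined as
#   RRSE = sqrt(sum((y_predicted - y_true)^2) / sum((mean(y_true) - y_true)^2)),
# and since R^2 = 1 - SS_res/SS_tot (with both sums weighted accordingly when sample
# weights are given), this is exactly sqrt(1 - R^2), which is what is computed above.
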

class RegressionMetricR2(RegressionMetric):
    name = "R2"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return r2_score(y_true, y_predicted, sample_weight=weights)


class RegressionMetricPCC(RegressionMetric):
    """
    Pearson's correlation coefficient, also known as Pearson's r.
    This metric does not consider sample weights.
    """
    name = "PCC"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        cov = np.cov([y_true, y_predicted])
        return cov[0][1] / np.sqrt(cov[0][0] * cov[1][1])
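

# Note: for the 1-D arrays used here, the covariance-based expression above is equivalent
# to np.corrcoef(y_true, y_predicted)[0, 1].
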

class RegressionMetricStdDevAE(RegressionMetric):
    """
    The standard deviation of the absolute error.
    This metric does not consider sample weights.
    """
    name = "StdDevAE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return np.std(self.compute_abs_errors(y_true, y_predicted))


class RegressionMetricMedianAE(RegressionMetric):
    """
    The median absolute error.
    This metric does not consider sample weights.
    """
    name = "MedianAE"

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        return np.median(self.compute_abs_errors(y_true, y_predicted))


class RegressionMetricFromBinaryClassificationMetric(RegressionMetric):
    """
    Supports the computation of binary classification metrics by converting predicted/target values to class labels.
    This metric does not consider sample weights.
    """

    class ClassGenerator(ABC):
        @abstractmethod
        def compute_class(self, predicted_value: float) -> bool:
            """
            Computes the class from the given value.

            :param predicted_value: the value predicted by the regressor or the regression target value
            :return: the class
            """
            pass

        @abstractmethod
        def get_metric_qualifier(self) -> str:
            """
            :return: a (short) string which will be added to the original classification metric's name to
                represent the class conversion logic
            """
            pass

    class ClassGeneratorPositiveBeyond(ClassGenerator):
        def __init__(self, min_value_for_positive: float):
            self.min_value_for_positive = min_value_for_positive

        def compute_class(self, predicted_value: float) -> bool:
            return predicted_value >= self.min_value_for_positive

        def get_metric_qualifier(self) -> str:
            return f">={self.min_value_for_positive}"

    def __init__(self, classification_metric: BinaryClassificationMetric,
                 class_generator: ClassGenerator):
        """
        :param classification_metric: the classification metric (which shall consider `True` as the positive label)
        :param class_generator: the class generator, which generates `True` and `False` labels from regression values
        """
        super().__init__(name=classification_metric.name + f"[{class_generator.get_metric_qualifier()}]",
                         bounds=classification_metric.bounds)
        self.classification_metric = classification_metric
        self.class_generator = class_generator

    def _apply_class_generator(self, y: np.ndarray) -> np.ndarray:
        return np.array([self.class_generator.compute_class(v) for v in y])

    def compute_value(self, y_true: np.ndarray, y_predicted: np.ndarray, model: VectorRegressionModel = None,
                      io_data: InputOutputData = None, weights: Optional[np.ndarray] = None):
        y_true = self._apply_class_generator(y_true)
        y_predicted = self._apply_class_generator(y_predicted)
        return self.classification_metric.compute_value(y_true=y_true, y_predicted=y_predicted)
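

# Usage sketch (the metric instance is hypothetical; any BinaryClassificationMetric that
# treats `True` as the positive label can be wrapped):
#
#     metric = RegressionMetricFromBinaryClassificationMetric(
#         classification_metric=some_binary_classification_metric,
#         class_generator=RegressionMetricFromBinaryClassificationMetric.ClassGeneratorPositiveBeyond(0.5))
#
# The resulting metric evaluates how well the regressor separates values >= 0.5 from values
# below it, and its name carries the qualifier, e.g. "Precision[>=0.5]".
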

class HeatMapColorMapFactory(ABC):
    @abstractmethod
    def create_color_map(self, min_sample_weight: float, total_weight: float, num_quantization_levels: int):
        pass


class HeatMapColorMapFactoryWhiteToRed(HeatMapColorMapFactory):
    def create_color_map(self, min_sample_weight: float, total_weight: float, num_quantization_levels: int):
        color_nothing = (1, 1, 1)  # white
        color_min_sample = (1, 0.96, 0.96)  # very slightly red
        color_everything = (0.7, 0, 0)  # dark red
        return LinearSegmentedColormap.from_list("whiteToRed",
                                                 ((0, color_nothing), (min_sample_weight / total_weight, color_min_sample),
                                                  (1, color_everything)),
                                                 num_quantization_levels)
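

# Note on the colour map above: anchoring the second colour stop at the fraction
# min_sample_weight/total_weight ensures that a heat map cell containing even a single
# (minimally weighted) sample is rendered visibly distinct from an empty (white) cell,
# with the remaining range fading towards dark red as the cell's share of the total
# weight increases.
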

DEFAULT_REGRESSION_METRICS = (RegressionMetricRRSE(), RegressionMetricR2(), RegressionMetricMAE(),
                              RegressionMetricMSE(), RegressionMetricRMSE(), RegressionMetricStdDevAE())


class RegressionEvalStats(PredictionEvalStats["RegressionMetric"]):
    """
    Collects data for the evaluation of predicted continuous values and computes corresponding metrics
    """

    # class members controlling plot appearance, which can be centrally overridden by a user if necessary
    HEATMAP_COLORMAP_FACTORY = HeatMapColorMapFactoryWhiteToRed()
    HEATMAP_DIAGONAL_COLOR = "green"
    HEATMAP_ERROR_BOUNDARY_VALUE = None
    HEATMAP_ERROR_BOUNDARY_COLOR = (0.8, 0.8, 0.8)
    SCATTER_PLOT_POINT_COLOR = (0, 0, 1, 0.05)

    def __init__(self, y_predicted: Optional[PredictionArray] = None, y_true: Optional[PredictionArray] = None,
                 metrics: Optional[Sequence["RegressionMetric"]] = None, additional_metrics: Sequence["RegressionMetric"] = None,
                 model: VectorRegressionModel = None,
                 io_data: InputOutputData = None,
                 weights: Optional[Array] = None):
        """
        :param y_predicted: the predicted values
        :param y_true: the true values
        :param metrics: the metrics to compute for evaluation; if None, will use DEFAULT_REGRESSION_METRICS
        :param additional_metrics: the metrics to additionally compute
        :param model: the model that produced the predictions (if any), which is passed on to metrics that require it
        :param io_data: the input/output data to which the predictions pertain (if any), which is passed on to metrics that require it
        :param weights: optional data point weights
        """
        self.model = model
        self.ioData = io_data

        if metrics is None:
            metrics = DEFAULT_REGRESSION_METRICS
        metrics = list(metrics)

        super().__init__(y_predicted, y_true, metrics, additional_metrics=additional_metrics, weights=weights)

    def compute_metric_value(self, metric: RegressionMetric) -> float:
        return metric.compute_value_for_eval_stats(self)

    def compute_mse(self):
        """Computes the mean squared error (MSE)"""
        return self.compute_metric_value(RegressionMetricMSE())

    def compute_rrse(self):
        """Computes the root relative squared error (RRSE)"""
        return self.compute_metric_value(RegressionMetricRRSE())

    def compute_pcc(self):
        """Computes the Pearson correlation coefficient (PCC)"""
        return self.compute_metric_value(RegressionMetricPCC())

    def compute_r2(self):
        """Computes the R^2 score"""
        return self.compute_metric_value(RegressionMetricR2())

    def compute_mae(self):
        """Computes the mean absolute error (MAE)"""
        return self.compute_metric_value(RegressionMetricMAE())

    def compute_rmse(self):
        """Computes the root mean squared error (RMSE)"""
        return self.compute_metric_value(RegressionMetricRMSE())

    def compute_std_dev_ae(self):
        """Computes the standard deviation of the absolute error"""
        return self.compute_metric_value(RegressionMetricStdDevAE())
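
    # Usage sketch for the convenience methods above (illustrative values only):
    #
    #     stats = RegressionEvalStats(y_predicted=[2.4, 0.9, 3.1], y_true=[2.0, 1.0, 3.0])
    #     rmse = stats.compute_rmse()
    #     r2 = stats.compute_r2()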

    def create_eval_stats_collection(self) -> "RegressionEvalStatsCollection":
        """
        For the case where we collected data on multiple dimensions, obtain a stats collection where
        each object in the collection holds stats on just one dimension
        """
        if self.y_true_multidim is None:
            raise Exception("No multi-dimensional data was collected")
        dim = len(self.y_true_multidim)
        stats_list = []
        for i in range(dim):
            stats = RegressionEvalStats(self.y_predicted_multidim[i], self.y_true_multidim[i])
            stats_list.append(stats)
        return RegressionEvalStatsCollection(stats_list)

    def plot_error_distribution(self, bins="auto", title_add=None) -> Optional[plt.Figure]:
        """
        :param bins: bin specification (see :class:`HistogramPlot`)
        :param title_add: a string to add to the title (on a second line)

        :return: the resulting figure object or None
        """
        errors = np.array(self.y_predicted) - np.array(self.y_true)
        title = "Prediction Error Distribution"
        if title_add is not None:
            title += "\n" + title_add
        if bins == "auto" and len(errors) < 100:
            bins = 10  # seaborn can crash with a low number of data points and bins="auto" (it may try to allocate vast amounts of memory)
        plot = HistogramPlot(errors, bins=bins, kde=True)
        plot.title(title)
        plot.xlabel("error (prediction - ground truth)")
        plot.ylabel("probability density")
        return plot.fig

    def plot_scatter_ground_truth_predictions(self, figure=True, title_add=None, **kwargs) -> Optional[plt.Figure]:
        """
        :param figure: whether to plot in a separate figure and return that figure
        :param title_add: a string to add to the title (on a second line)
        :param kwargs: parameters to be passed on to plt.scatter()

        :return: the resulting figure object or None
        """
        fig = None
        title = "Scatter Plot of Predicted Values vs. Ground Truth"
        if title_add is not None:
            title += "\n" + title_add
        if figure:
            fig = plt.figure(title.replace("\n", " "))
        y_range = [min(self.y_true), max(self.y_true)]
        plt.scatter(self.y_true, self.y_predicted, c=[self.SCATTER_PLOT_POINT_COLOR], zorder=2, **kwargs)
        plt.plot(y_range, y_range, '-', lw=1, label="_not in legend", color="green", zorder=1)
        plt.xlabel("ground truth")
        plt.ylabel("prediction")
        plt.title(title)
        return fig

    def plot_heatmap_ground_truth_predictions(self, figure=True, cmap=None, bins=60, title_add=None,
                                              error_boundary: Optional[float] = None,
                                              weighted: bool = False, ax: Optional[plt.Axes] = None,
                                              **kwargs) -> Optional[plt.Figure]:
        """
        :param figure: whether to create a new figure and return that figure (only applies if ax is None)
        :param cmap: the colour map to use (see corresponding parameter of plt.imshow for further information); if None, use the factory
            defined in HEATMAP_COLORMAP_FACTORY (which can be centrally set to achieve custom behaviour throughout an application)
        :param bins: how many bins to use for constructing the heat map
        :param title_add: a string to add to the title (on a second line)
        :param error_boundary: if not None, add two lines (above and below the diagonal) indicating this absolute regression error boundary;
            if None (default), use static member HEATMAP_ERROR_BOUNDARY_VALUE (which is also None by default, but can be centrally set
            to achieve custom behaviour throughout an application)
        :param weighted: whether to consider data point weights
        :param ax: the axes to plot in; if None, use the current axes (which will be the axes of the newly created figure if figure=True)
        :param kwargs: will be passed on to plt.imshow()

        :return: the newly created figure object (if figure=True) or None
        """
        fig = None
        title = "Heat Map of Predicted Values vs. Ground Truth"
        if title_add:
            title += "\n" + title_add
        if figure and ax is None:
            fig = plt.figure(title.replace("\n", " "))
        if ax is None:
            ax = plt.gca()

        y_range = [min(min(self.y_true), min(self.y_predicted)), max(max(self.y_true), max(self.y_predicted))]

        # diagonal
        ax.plot(y_range, y_range, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_DIAGONAL_COLOR, zorder=2)

        # error boundaries
        if error_boundary is None:
            error_boundary = self.HEATMAP_ERROR_BOUNDARY_VALUE
        if error_boundary is not None:
            d = np.array(y_range)
            offs = np.array([error_boundary, error_boundary])
            ax.plot(d, d + offs, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_ERROR_BOUNDARY_COLOR, zorder=2)
            ax.plot(d, d - offs, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_ERROR_BOUNDARY_COLOR, zorder=2)

        # heat map
        weights = None if not weighted else self.weights
        heatmap, _, _ = np.histogram2d(self.y_true, self.y_predicted, range=(y_range, y_range), bins=bins, density=False, weights=weights)
        extent = (y_range[0], y_range[1], y_range[0], y_range[1])
        if cmap is None:
            num_quantization_levels = min(1000, len(self.y_predicted))
            if not weighted:
                min_sample_weight = 1.0
                total_weight = len(self.y_predicted)
            else:
                min_sample_weight = np.min(self.weights)
                total_weight = np.sum(self.weights)
            cmap = self.HEATMAP_COLORMAP_FACTORY.create_color_map(min_sample_weight, total_weight, num_quantization_levels)
        ax.imshow(heatmap.T, extent=extent, origin='lower', interpolation="none", cmap=cmap, zorder=1, **kwargs)

        ax.set_xlabel("ground truth")
        ax.set_ylabel("prediction")
        ax.set_title(title)
        return fig
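

# Usage sketch for the plotting methods (illustrative values only; in practice, the stats
# object is usually created by sensAI's evaluation utilities):
#
#     stats = RegressionEvalStats(y_predicted=[2.4, 0.9, 3.1], y_true=[2.0, 1.0, 3.0])
#     fig1 = stats.plot_error_distribution()
#     fig2 = stats.plot_heatmap_ground_truth_predictions(error_boundary=0.5)
#     plt.show()
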

class RegressionEvalStatsCollection(EvalStatsCollection[RegressionEvalStats, RegressionMetric]):
    def __init__(self, eval_stats_list: List[RegressionEvalStats]):
        super().__init__(eval_stats_list)
        self.globalStats = None

    def get_combined_eval_stats(self) -> RegressionEvalStats:
        if self.globalStats is None:
            y_true = np.concatenate([evalStats.y_true for evalStats in self.statsList])
            y_predicted = np.concatenate([evalStats.y_predicted for evalStats in self.statsList])
            es0 = self.statsList[0]
            self.globalStats = RegressionEvalStats(y_predicted, y_true, metrics=es0.metrics)
        return self.globalStats
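

# Note: get_combined_eval_stats pools the raw ground truth and prediction values of all
# collected stats objects and recomputes the metrics on the pooled data; this generally
# differs from averaging the per-object metric values (e.g. for RMSE or R2).
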

class RegressionEvalStatsPlot(EvalStatsPlot[RegressionEvalStats], ABC):
    pass


class RegressionEvalStatsPlotErrorDistribution(RegressionEvalStatsPlot):
    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        return eval_stats.plot_error_distribution(title_add=subtitle)


class RegressionEvalStatsPlotHeatmapGroundTruthPredictions(RegressionEvalStatsPlot):
    def __init__(self, weighted: bool = False):
        self.weighted = weighted

    def is_applicable(self, eval_stats: RegressionEvalStats) -> bool:
        if self.weighted:
            return eval_stats.weights is not None
        else:
            return True

    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        return eval_stats.plot_heatmap_ground_truth_predictions(title_add=subtitle, weighted=self.weighted)


class RegressionEvalStatsPlotScatterGroundTruthPredictions(RegressionEvalStatsPlot):
    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        return eval_stats.plot_scatter_ground_truth_predictions(title_add=subtitle)