Coverage for src/sensai/evaluation/eval_stats/eval_stats_regression.py: 58%
182 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2from abc import abstractmethod, ABC
3from typing import List, Sequence, Optional
5import numpy as np
6from matplotlib import pyplot as plt
7from matplotlib.colors import LinearSegmentedColormap
9from .eval_stats_base import PredictionEvalStats, Metric, EvalStatsCollection, PredictionArray, EvalStatsPlot
10from ...vector_model import VectorRegressionModel, InputOutputData
11from ...util.plot import HistogramPlot
13log = logging.getLogger(__name__)
class RegressionMetric(Metric["RegressionEvalStats"], ABC):
    """
    Base class for metrics that are computed from ground truth values and predicted values
    of a regression problem.
    """

    def compute_value_for_eval_stats(
            self,
            eval_stats: "RegressionEvalStats",
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        Computes the metric value for the data held by the given eval stats object.

        :param eval_stats: the object providing ``y_true`` and ``y_predicted``
        :param model: passed on unchanged to :meth:`compute_value`
        :param io_data: passed on unchanged to :meth:`compute_value`
        :return: the result of :meth:`compute_value`
        """
        # the collected values are converted to arrays before being handed to the actual computation
        ground_truth = np.array(eval_stats.y_true)
        predictions = np.array(eval_stats.y_predicted)
        return self.compute_value(ground_truth, predictions, model=model, io_data=io_data)

    @classmethod
    @abstractmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        Computes the metric value; to be implemented by subclasses.

        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: optional model argument (see subclasses for whether it is used)
        :param io_data: optional data argument (see subclasses for whether it is used)
        :return: the metric value
        """

    @classmethod
    def compute_errors(cls, y_true: np.ndarray, y_predicted: np.ndarray):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :return: the signed errors, i.e. prediction minus ground truth
        """
        signed_errors = y_predicted - y_true
        return signed_errors

    @classmethod
    def compute_abs_errors(cls, y_true: np.ndarray, y_predicted: np.ndarray):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :return: the absolute values of the errors returned by :meth:`compute_errors`
        """
        signed_errors = cls.compute_errors(y_true, y_predicted)
        return np.abs(signed_errors)
class RegressionMetricMAE(RegressionMetric):
    """
    Mean absolute error: the mean of the absolute errors
    """
    name = "MAE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the mean of the absolute errors
        """
        abs_errors = cls.compute_abs_errors(y_true, y_predicted)
        return np.mean(abs_errors)
class RegressionMetricMSE(RegressionMetric):
    """
    Mean squared error: the sum of squared errors divided by the number of entries
    """
    name = "MSE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the sum of the squared differences divided by ``len`` of the differences
        """
        diff = y_predicted - y_true
        sum_of_squares = np.sum(np.square(diff))
        return sum_of_squares / len(diff)
class RegressionMetricRMSE(RegressionMetric):
    """
    Root mean squared error: the square root of the mean of the squared errors
    """
    name = "RMSE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the square root of the mean squared error
        """
        deviations = cls.compute_errors(y_true, y_predicted)
        mean_squared = np.mean(np.square(deviations))
        return np.sqrt(mean_squared)
class RegressionMetricRRSE(RegressionMetric):
    """
    Root relative squared error: the square root of the ratio between the sum of squared residuals
    and the sum of squared deviations of the true values from their mean
    """
    name = "RRSE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: sqrt(sum of squared residuals / sum of squared deviations of y_true from its mean);
            note that the denominator is zero if all true values are identical
        """
        residuals = y_predicted - y_true
        deviation_from_mean = y_true - np.mean(y_true)
        residual_sum_of_squares = np.sum(np.square(residuals))
        total_sum_of_squares = np.sum(np.square(deviation_from_mean))
        return np.sqrt(residual_sum_of_squares / total_sum_of_squares)
class RegressionMetricR2(RegressionMetric):
    """
    R^2 score, computed as one minus the squared root relative squared error (RRSE)
    """
    name = "R2"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used (not passed on to the RRSE computation)
        :param io_data: not used (not passed on to the RRSE computation)
        :return: 1 - RRSE^2
        """
        root_relative_squared_error = RegressionMetricRRSE.compute_value(y_true, y_predicted)
        relative_squared_error = root_relative_squared_error * root_relative_squared_error
        return 1.0 - relative_squared_error
class RegressionMetricPCC(RegressionMetric):
    """
    Pearson correlation coefficient between the true values and the predicted values
    """
    name = "PCC"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the covariance of the two series divided by the square root of the product of their variances
        """
        # 2x2 covariance matrix: row/column 0 is the ground truth, row/column 1 the predictions
        cov_matrix = np.cov([y_true, y_predicted])
        covariance = cov_matrix[0][1]
        variance_true = cov_matrix[0][0]
        variance_predicted = cov_matrix[1][1]
        return covariance / np.sqrt(variance_true * variance_predicted)
class RegressionMetricStdDevAE(RegressionMetric):
    """
    Standard deviation of the absolute errors
    """
    name = "StdDevAE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the standard deviation (as computed by ``np.std``) of the absolute errors
        """
        abs_errors = cls.compute_abs_errors(y_true, y_predicted)
        return np.std(abs_errors)
class RegressionMetricMedianAE(RegressionMetric):
    """
    Median of the absolute errors
    """
    name = "MedianAE"

    @classmethod
    def compute_value(
            cls,
            y_true: np.ndarray,
            y_predicted: np.ndarray,
            model: VectorRegressionModel = None,
            io_data: InputOutputData = None):
        """
        :param y_true: the true values
        :param y_predicted: the predicted values
        :param model: not used
        :param io_data: not used
        :return: the median of the absolute errors
        """
        abs_errors = cls.compute_abs_errors(y_true, y_predicted)
        return np.median(abs_errors)
# the metrics which RegressionEvalStats applies when no metrics are explicitly specified
DEFAULT_REGRESSION_METRICS = (
    RegressionMetricRRSE(),
    RegressionMetricR2(),
    RegressionMetricMAE(),
    RegressionMetricMSE(),
    RegressionMetricRMSE(),
    RegressionMetricStdDevAE(),
)
class RegressionEvalStats(PredictionEvalStats["RegressionMetric"]):
    """
    Collects data for the evaluation of predicted continuous values and computes corresponding metrics
    """

    # class members controlling plot appearance, which can be centrally overridden by a user if necessary

    # factory for the default heat map colour map, called on the instance: white at 0, a very light red at 1/n
    # and dark red at 1, where n = len(self.y_predicted) is also passed as the number of colour levels
    HEATMAP_COLORMAP_FACTORY = lambda self: LinearSegmentedColormap.from_list("whiteToRed",
        ((0, (1, 1, 1)), (1/len(self.y_predicted), (1, 0.96, 0.96)), (1, (0.7, 0, 0))), len(self.y_predicted))
    # colour of the diagonal (prediction == ground truth) in the heat map
    HEATMAP_DIAGONAL_COLOR = "green"
    # default absolute error boundary to draw in the heat map (None = draw no boundary lines)
    HEATMAP_ERROR_BOUNDARY_VALUE = None
    # colour of the error boundary lines in the heat map
    HEATMAP_ERROR_BOUNDARY_COLOR = (0.8, 0.8, 0.8)
    # colour (RGBA) of the points in the scatter plot
    SCATTER_PLOT_POINT_COLOR = (0, 0, 1, 0.05)

    def __init__(self, y_predicted: Optional[PredictionArray] = None, y_true: Optional[PredictionArray] = None,
            metrics: Optional[Sequence["RegressionMetric"]] = None,
            additional_metrics: Optional[Sequence["RegressionMetric"]] = None,
            model: Optional[VectorRegressionModel] = None, io_data: Optional[InputOutputData] = None):
        """
        :param y_predicted: the predicted values
        :param y_true: the true values
        :param metrics: the metrics to compute for evaluation; if None, will use DEFAULT_REGRESSION_METRICS
        :param additional_metrics: the metrics to additionally compute
        :param model: the model, which is passed on to the metrics when computing their values
        :param io_data: the data, which is passed on to the metrics when computing their values
        """
        self.model = model
        self.ioData = io_data

        if metrics is None:
            metrics = DEFAULT_REGRESSION_METRICS
        # copy into a new list, such that the sequence that was passed (or the default tuple) is not shared
        metrics = list(metrics)

        super().__init__(y_predicted, y_true, metrics, additional_metrics=additional_metrics)

    def compute_metric_value(self, metric: RegressionMetric) -> float:
        """
        :param metric: the metric to evaluate
        :return: the metric's value for the data held by this object (using this object's model and io_data)
        """
        return metric.compute_value_for_eval_stats(self, model=self.model, io_data=self.ioData)

    def compute_mse(self):
        """Computes the mean squared error (MSE)"""
        return self.compute_metric_value(RegressionMetricMSE())

    def compute_rrse(self):
        """Computes the root relative squared error"""
        return self.compute_metric_value(RegressionMetricRRSE())

    def compute_pcc(self):
        """Gets the Pearson correlation coefficient (PCC)"""
        return self.compute_metric_value(RegressionMetricPCC())

    def compute_r2(self):
        """Gets the R^2 score"""
        return self.compute_metric_value(RegressionMetricR2())

    def compute_mae(self):
        """Gets the mean absolute error"""
        return self.compute_metric_value(RegressionMetricMAE())

    def compute_rmse(self):
        """Gets the root mean squared error"""
        return self.compute_metric_value(RegressionMetricRMSE())

    def compute_std_dev_ae(self):
        """Gets the standard deviation of the absolute error"""
        return self.compute_metric_value(RegressionMetricStdDevAE())

    def create_eval_stats_collection(self) -> "RegressionEvalStatsCollection":
        """
        For the case where we collected data on multiple dimensions, obtain a stats collection where
        each object in the collection holds stats on just one dimension.

        Note that the per-dimension objects are created from the predictions and true values only, i.e.
        they use the default metrics and hold neither model nor io_data.

        :return: the collection containing one stats object per dimension
        :raises Exception: if no multi-dimensional data was collected
        """
        if self.y_true_multidim is None:
            raise Exception("No multi-dimensional data was collected")
        dim = len(self.y_true_multidim)
        stats_list = [RegressionEvalStats(self.y_predicted_multidim[i], self.y_true_multidim[i]) for i in range(dim)]
        return RegressionEvalStatsCollection(stats_list)

    def plot_error_distribution(self, bins="auto", title_add=None) -> Optional[plt.Figure]:
        """
        :param bins: bin specification (see :class:`HistogramPlot`)
        :param title_add: a string to add to the title (on a second line)

        :return: the resulting figure object or None
        """
        errors = np.array(self.y_predicted) - np.array(self.y_true)
        title = "Prediction Error Distribution"
        if title_add is not None:
            title += "\n" + title_add
        # Only compare against "auto" if bins is actually a string: for an array-like bin specification,
        # the comparison would be carried out elementwise by NumPy and could not be used as a truth value.
        if isinstance(bins, str) and bins == "auto" and len(errors) < 100:
            bins = 10  # seaborn can crash with low number of data points and bins="auto" (tries to allocate vast amounts of memory)
        plot = HistogramPlot(errors, bins=bins, kde=True)
        plot.title(title)
        plot.xlabel("error (prediction - ground truth)")
        plot.ylabel("probability density")
        return plot.fig

    def plot_scatter_ground_truth_predictions(self, figure=True, title_add=None, **kwargs) -> Optional[plt.Figure]:
        """
        :param figure: whether to plot in a separate figure and return that figure
        :param title_add: a string to be added to the title in a second line
        :param kwargs: parameters to be passed on to plt.scatter()

        :return: the resulting figure object or None
        """
        fig = None
        title = "Scatter Plot of Predicted Values vs. Ground Truth"
        if title_add is not None:
            title += "\n" + title_add
        if figure:
            # the figure's identifier is the title on a single line
            fig = plt.figure(title.replace("\n", " "))
        # the diagonal is drawn across the range of the true values
        y_range = [min(self.y_true), max(self.y_true)]
        plt.scatter(self.y_true, self.y_predicted, c=[self.SCATTER_PLOT_POINT_COLOR], zorder=2, **kwargs)
        plt.plot(y_range, y_range, '-', lw=1, label="_not in legend", color="green", zorder=1)
        plt.xlabel("ground truth")
        plt.ylabel("prediction")
        plt.title(title)
        return fig

    def plot_heatmap_ground_truth_predictions(self, figure=True, cmap=None, bins=60, title_add=None,
            error_boundary: Optional[float] = None, **kwargs) -> Optional[plt.Figure]:
        """
        :param figure: whether to plot in a separate figure and return that figure
        :param cmap: the colour map to use (see corresponding parameter of plt.imshow for further information); if None, use factory
            defined in HEATMAP_COLORMAP_FACTORY (which can be centrally set to achieve custom behaviour throughout an application)
        :param bins: how many bins to use for constructing the heatmap
        :param title_add: a string to add to the title (on a second line)
        :param error_boundary: if not None, add two lines (above and below the diagonal) indicating this absolute regression error boundary;
            if None (default), use static member HEATMAP_ERROR_BOUNDARY_VALUE (which is also None by default, but can be centrally set
            to achieve custom behaviour throughout an application)
        :param kwargs: will be passed to plt.imshow()

        :return: the resulting figure object or None
        """
        fig = None
        title = "Heat Map of Predicted Values vs. Ground Truth"
        if title_add:
            title += "\n" + title_add
        if figure:
            # the figure's identifier is the title on a single line
            fig = plt.figure(title.replace("\n", " "))
        # common value range of both axes, covering all true and all predicted values
        y_range = [min(min(self.y_true), min(self.y_predicted)), max(max(self.y_true), max(self.y_predicted))]

        # diagonal
        plt.plot(y_range, y_range, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_DIAGONAL_COLOR, zorder=2)

        # error boundaries
        if error_boundary is None:
            error_boundary = self.HEATMAP_ERROR_BOUNDARY_VALUE
        if error_boundary is not None:
            d = np.array(y_range)
            offs = np.array([error_boundary, error_boundary])
            plt.plot(d, d + offs, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_ERROR_BOUNDARY_COLOR, zorder=2)
            plt.plot(d, d - offs, '-', lw=0.75, label="_not in legend", color=self.HEATMAP_ERROR_BOUNDARY_COLOR, zorder=2)

        # heat map
        heatmap, _, _ = np.histogram2d(self.y_true, self.y_predicted, range=[y_range, y_range], bins=bins, density=False)
        extent = [y_range[0], y_range[1], y_range[0], y_range[1]]
        if cmap is None:
            cmap = self.HEATMAP_COLORMAP_FACTORY()
        # the histogram is transposed such that the ground truth is on the horizontal axis
        plt.imshow(heatmap.T, extent=extent, origin='lower', interpolation="none", cmap=cmap, zorder=1, **kwargs)

        plt.xlabel("ground truth")
        plt.ylabel("prediction")
        plt.title(title)
        return fig
class RegressionEvalStatsCollection(EvalStatsCollection[RegressionEvalStats, RegressionMetric]):
    """
    A collection of regression eval stats objects, which can additionally provide a single stats object
    for the combined data of all its members
    """

    def __init__(self, eval_stats_list: List[RegressionEvalStats]):
        """
        :param eval_stats_list: the stats objects making up the collection
        """
        super().__init__(eval_stats_list)
        # lazily created by get_combined_eval_stats
        self.globalStats = None

    def get_combined_eval_stats(self) -> RegressionEvalStats:
        """
        :return: a stats object holding the concatenated true and predicted values of all stats objects
            in the collection, using the metrics of the first one; it is created on the first call and
            the same object is returned on subsequent calls
        """
        if self.globalStats is not None:
            return self.globalStats
        all_true = np.concatenate([stats.y_true for stats in self.statsList])
        all_predicted = np.concatenate([stats.y_predicted for stats in self.statsList])
        first_stats = self.statsList[0]
        self.globalStats = RegressionEvalStats(all_predicted, all_true, metrics=first_stats.metrics)
        return self.globalStats
class RegressionEvalStatsPlot(EvalStatsPlot[RegressionEvalStats], ABC):
    """
    Base class for plots that are created from a :class:`RegressionEvalStats` object
    """
class RegressionEvalStatsPlotErrorDistribution(RegressionEvalStatsPlot):
    """
    Plot of the distribution of prediction errors
    """

    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        """
        :param eval_stats: the stats object to create the plot for
        :param subtitle: the text to add to the plot's title
        :return: the result of the stats object's plot_error_distribution method
        """
        fig = eval_stats.plot_error_distribution(title_add=subtitle)
        return fig
class RegressionEvalStatsPlotHeatmapGroundTruthPredictions(RegressionEvalStatsPlot):
    """
    Heat map of predicted values vs. ground truth
    """

    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        """
        :param eval_stats: the stats object to create the plot for
        :param subtitle: the text to add to the plot's title
        :return: the result of the stats object's plot_heatmap_ground_truth_predictions method
        """
        fig = eval_stats.plot_heatmap_ground_truth_predictions(title_add=subtitle)
        return fig
class RegressionEvalStatsPlotScatterGroundTruthPredictions(RegressionEvalStatsPlot):
    """
    Scatter plot of predicted values vs. ground truth
    """

    def create_figure(self, eval_stats: RegressionEvalStats, subtitle: str) -> plt.Figure:
        """
        :param eval_stats: the stats object to create the plot for
        :param subtitle: the text to add to the plot's title
        :return: the result of the stats object's plot_scatter_ground_truth_predictions method
        """
        fig = eval_stats.plot_scatter_ground_truth_predictions(title_add=subtitle)
        return fig