# src/sensai/evaluation/eval_stats/eval_stats_base.py
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Union, Dict, Sequence, Optional, Tuple, Callable

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ...util.pickle import setstate
from ...util.plot import ScatterPlot, HistogramPlot, Plot, HeatMapPlot
from ...util.string import ToStringMixin, dict_string
from ...vector_model import VectorModel

# Note: in the 2020.2 version of PyCharm, passing strings to `bound` is highlighted as an error.
# It does not cause runtime errors and the static type checker ignores the bound anyway, so it does not matter for now.
# However, this might cause problems with type checking in the future. Therefore, I moved the definition of TEvalStats
# below the definition of EvalStats. Unfortunately, the generic dependency between EvalStats and Metric
# does not allow both TMetric and TEvalStats to be defined properly. For now, we have to leave the bound as a string
# and hope for the best in the future.
TMetric = TypeVar("TMetric", bound="Metric")
TVectorModel = TypeVar("TVectorModel", bound=VectorModel)

Array = Union[np.ndarray, pd.Series, list]
PredictionArray = Union[np.ndarray, pd.Series, pd.DataFrame, list]


class EvalStats(Generic[TMetric], ToStringMixin):
    def __init__(self, metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        if len(metrics) == 0:
            raise ValueError("No metrics provided")
        self.metrics = metrics
        # Implementations of EvalStats will typically provide default metrics; therefore, we include
        # the possibility of passing additional metrics here
        if additional_metrics is not None:
            self.metrics = self.metrics + additional_metrics
        self.name = None

    def set_name(self, name: str):
        self.name = name

    def add_metric(self, metric: TMetric):
        self.metrics.append(metric)

    def compute_metric_value(self, metric: TMetric) -> float:
        return metric.compute_value_for_eval_stats(self)

    def metrics_dict(self) -> Dict[str, float]:
        """
        Computes all metrics

        :return: a dictionary mapping metric names to values
        """
        d = {}
        for metric in self.metrics:
            d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_all(self) -> Dict[str, float]:
        """Alias for metrics_dict; may be deprecated in the future"""
        return self.metrics_dict()

    def _tostring_object_info(self) -> str:
        return dict_string(self.metrics_dict())


TEvalStats = TypeVar("TEvalStats", bound=EvalStats)


class Metric(Generic[TEvalStats], ABC):
    name: str

    def __init__(self, name: str = None, bounds: Optional[Tuple[float, float]] = None):
        """
        :param name: the name of the metric; if None, use the class' name attribute
        :param bounds: the minimum and maximum values the metric can take on (or None if the bounds are not specified)
        """
        # this raises an attribute error if a subclass specifies a name neither as a static attribute nor as a parameter
        self.name = name if name is not None else self.__class__.name
        self.bounds = bounds

    @abstractmethod
    def compute_value_for_eval_stats(self, eval_stats: TEvalStats) -> float:
        pass

    def get_paired_metrics(self) -> List[TMetric]:
        """
        Gets a list of metrics that should be considered together with this metric (e.g. for paired visualisations/plots).
        The direction of the pairing should be such that if this metric is "x", the other is "y" for x-y type visualisations.

        :return: a list of metrics
        """
        return []

    def has_finite_bounds(self) -> bool:
        return self.bounds is not None and not any(np.isinf(x) for x in self.bounds)
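

# A minimal sketch (commented out; not part of this module) of how a custom metric might be defined,
# assuming a concrete EvalStats subclass that exposes y_true/y_predicted lists (such as
# PredictionEvalStats below). The class name MeanErrorMetric is a hypothetical example, not a
# metric provided by the library.
#
#     class MeanErrorMetric(Metric["PredictionEvalStats"]):
#         name = "MeanError"
#
#         def compute_value_for_eval_stats(self, eval_stats: "PredictionEvalStats") -> float:
#             # mean signed difference between predictions and ground truth
#             return float(np.mean(np.array(eval_stats.y_predicted) - np.array(eval_stats.y_true)))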


class EvalStatsCollection(Generic[TEvalStats, TMetric], ABC):
    def __init__(self, eval_stats_list: List[TEvalStats]):
        self.statsList = eval_stats_list
        metric_names_set = None
        metrics_list = []
        for es in eval_stats_list:
            metrics = es.metrics_dict()
            current_metric_names_set = set(metrics.keys())
            if metric_names_set is None:
                metric_names_set = current_metric_names_set
            else:
                if metric_names_set != current_metric_names_set:
                    raise Exception(f"Inconsistent set of metrics in evaluation stats collection: "
                                    f"Got {metric_names_set} for one instance, {current_metric_names_set} for another")
            metrics_list.append(metrics)
        metric_names = sorted(metrics_list[0].keys())
        self._valuesByMetricName = {metric: [d[metric] for d in metrics_list] for metric in metric_names}
        self._metrics: List[TMetric] = eval_stats_list[0].metrics

    def get_values(self, metric_name: str):
        return self._valuesByMetricName[metric_name]

    def get_metric_names(self) -> List[str]:
        return list(self._valuesByMetricName.keys())

    def get_metrics(self) -> List[TMetric]:
        return self._metrics

    def get_metric_by_name(self, name: str) -> Optional[TMetric]:
        for m in self._metrics:
            if m.name == name:
                return m
        return None

    def has_metric(self, metric: Union[Metric, str]) -> bool:
        if type(metric) != str:
            metric = metric.name
        return metric in self._valuesByMetricName

    def agg_metrics_dict(self, agg_fns=(np.mean, np.std)) -> Dict[str, float]:
        agg = {}
        for metric, values in self._valuesByMetricName.items():
            for agg_fn in agg_fns:
                agg[f"{agg_fn.__name__}[{metric}]"] = float(agg_fn(values))
        return agg

    def mean_metrics_dict(self) -> Dict[str, float]:
        metrics = {metric: np.mean(values) for (metric, values) in self._valuesByMetricName.items()}
        return metrics

    def plot_distribution(self, metric_name: str, subtitle: Optional[str] = None, bins=None, kde=False, cdf=False,
            cdf_complementary=False, stat="proportion", **kwargs) -> plt.Figure:
        """
        Plots the distribution of a metric as a histogram

        :param metric_name: name of the metric for which to plot the distribution (histogram) across evaluations
        :param subtitle: the subtitle to add, if any
        :param bins: the histogram bins (number of bins or boundaries); the metric's bounds will be used to define the x limits.
            If None, use 'auto' bins
        :param kde: whether to add a kernel density estimator plot
        :param cdf: whether to add the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if ``cdf`` is True, the complementary cdf instead of the regular cdf
        :param stat: the statistic to compute for each bin ('percent', 'probability'='proportion', 'count', 'frequency' or
            'density'), i.e. the y-axis value
        :param kwargs: additional parameters to pass to seaborn.histplot (see https://seaborn.pydata.org/generated/seaborn.histplot.html)
        :return: the plot
        """
        # define bins based on metric bounds where available
        x_tick = None
        if bins is None or type(bins) == int:
            metric = self.get_metric_by_name(metric_name)
            if metric.bounds == (0, 1):
                x_tick = 0.1
                if bins is None:
                    num_bins = 10 if cdf else 20
                else:
                    num_bins = bins
                bins = np.linspace(0, 1, num_bins + 1)
            else:
                bins = "auto"

        values = self._valuesByMetricName[metric_name]
        title = metric_name
        if subtitle is not None:
            title += "\n" + subtitle
        plot = HistogramPlot(values, bins=bins, stat=stat, kde=kde, cdf=cdf, cdf_complementary=cdf_complementary, **kwargs).title(title)
        if x_tick is not None:
            plot.xtick_major(x_tick)
        return plot.fig

    def _plot_xy(self, metric_name_x, metric_name_y, plot_factory: Callable[[Sequence, Sequence], Plot], adjust_bounds: bool) -> plt.Figure:
        def axlim(bounds):
            min_value, max_value = bounds
            diff = max_value - min_value
            return (min_value - 0.05 * diff, max_value + 0.05 * diff)

        x = self._valuesByMetricName[metric_name_x]
        y = self._valuesByMetricName[metric_name_y]
        plot = plot_factory(x, y)
        plot.xlabel(metric_name_x)
        plot.ylabel(metric_name_y)
        mx = self.get_metric_by_name(metric_name_x)
        if adjust_bounds and mx.has_finite_bounds():
            plot.xlim(*axlim(mx.bounds))
        my = self.get_metric_by_name(metric_name_y)
        if adjust_bounds and my.has_finite_bounds():
            plot.ylim(*axlim(my.bounds))
        return plot.fig

    def plot_scatter(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, ScatterPlot, adjust_bounds=True)

    def plot_heat_map(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, HeatMapPlot, adjust_bounds=False)

    def to_data_frame(self) -> pd.DataFrame:
        """
        :return: a DataFrame with the evaluation metrics from all contained EvalStats objects,
            with the EvalStats' name field being used as the index if it is set
        """
        data = dict(self._valuesByMetricName)
        index = [stats.name for stats in self.statsList]
        if len([n for n in index if n is not None]) == 0:
            index = None
        return pd.DataFrame(data, index=index)

    def get_global_stats(self) -> TEvalStats:
        """
        Alias for `get_combined_eval_stats`
        """
        return self.get_combined_eval_stats()

    @abstractmethod
    def get_combined_eval_stats(self) -> TEvalStats:
        """
        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        pass

    def __str__(self):
        agg_metrics = self.agg_metrics_dict()
        return f"{self.__class__.__name__}[" + ", ".join([f"{key}={value:.4f}" for key, value in agg_metrics.items()]) + "]"


class PredictionEvalStats(EvalStats[TMetric], ABC):
    """
    Collects data for the evaluation of predicted values (including multi-dimensional predictions)
    and computes corresponding metrics
    """
    def __init__(self, y_predicted: Optional[PredictionArray], y_true: Optional[PredictionArray],
            metrics: List[TMetric], additional_metrics: List[TMetric] = None,
            weights: Optional[Array] = None):
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame
            with one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        :param metrics: list of metrics to be computed on the provided data
        :param additional_metrics: the metrics to additionally compute; this should only be provided if `metrics` is None
        :param weights: weights for each data point contained in `y_predicted` and `y_true`
        """
        self.y_true = []
        self.y_predicted = []
        self.weights: Optional[List[float]] = None
        self.y_true_multidim = None
        self.y_predicted_multidim = None
        if y_predicted is not None:
            self.add_all(y_predicted=y_predicted, y_true=y_true, weights=weights)
        super().__init__(metrics, additional_metrics=additional_metrics)

    def __setstate__(self, state):
        return setstate(PredictionEvalStats, self, state, new_optional_properties=["weights"])

    def add(self, y_predicted, y_true, weight: Optional[float] = None) -> None:
        """
        Adds a single pair of values to the evaluation

        :param y_predicted: the value predicted by the model
        :param y_true: the true value
        :param weight: an optional weight for the data point
        """
        self.y_true.append(y_true)
        self.y_predicted.append(y_predicted)
        if weight is not None:
            if self.weights is None:
                self.weights = []
            self.weights.append(weight)

    def add_all(self, y_predicted: PredictionArray, y_true: PredictionArray, weights: Optional[Array] = None) -> None:
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame
            with one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        :param weights: optional weights of data points
        """
        def is_sequence(x):
            return isinstance(x, pd.Series) or isinstance(x, list) or isinstance(x, np.ndarray)

        if is_sequence(y_predicted) and is_sequence(y_true):
            a, b = len(y_predicted), len(y_true)
            if a != b:
                raise Exception(f"Lengths differ (predicted {a}, truth {b})")
            if a > 0:
                first_item = y_predicted.iloc[0] if isinstance(y_predicted, pd.Series) else y_predicted[0]
                is_nested_sequence = is_sequence(first_item)
                if is_nested_sequence:
                    for y_true_i, y_predicted_i in zip(y_true, y_predicted):
                        self.add_all(y_predicted=y_predicted_i, y_true=y_true_i)
                else:
                    self.y_true.extend(y_true)
                    self.y_predicted.extend(y_predicted)
        elif isinstance(y_predicted, pd.DataFrame) and isinstance(y_true, pd.DataFrame):
            # keep track of multi-dimensional data (to be used later in get_eval_stats_collection)
            y_predicted_multidim = y_predicted.values
            y_true_multidim = y_true.values
            dim = y_predicted_multidim.shape[1]
            if dim != y_true_multidim.shape[1]:
                raise Exception("Dimension mismatch")
            if self.y_true_multidim is None:
                self.y_predicted_multidim = [[] for _ in range(dim)]
                self.y_true_multidim = [[] for _ in range(dim)]
            if len(self.y_predicted_multidim) != dim:
                raise Exception("Dimension mismatch")
            for i in range(dim):
                self.y_predicted_multidim[i].extend(y_predicted_multidim[:, i])
                self.y_true_multidim[i].extend(y_true_multidim[:, i])
            # convert to flat data for this stats object
            y_predicted = y_predicted_multidim.reshape(-1)
            y_true = y_true_multidim.reshape(-1)
            self.y_true.extend(y_true)
            self.y_predicted.extend(y_predicted)
        else:
            raise Exception(f"Unhandled data types: {type(y_predicted)}, {type(y_true)}")

        if weights is not None:
            if self.weights is None:
                self.weights = []
            assert len(weights) == len(self.y_predicted) - len(self.weights), "Length of weights does not match"
            self.weights.extend(weights)

    def _tostring_object_info(self) -> str:
        return f"{super()._tostring_object_info()}, N={len(self.y_predicted)}"


def mean_stats(eval_stats_list: Sequence[EvalStats]) -> Dict[str, float]:
    """
    For a list of EvalStats objects, computes the mean values of all metrics and returns them in a dictionary.
    Assumes that all provided EvalStats objects have the same metrics.
    """
    dicts = [s.metrics_dict() for s in eval_stats_list]
    metrics = dicts[0].keys()
    return {m: np.mean([d[m] for d in dicts]) for m in metrics}
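

# For example (hypothetical values): if es1.metrics_dict() == {"MAE": 0.2} and
# es2.metrics_dict() == {"MAE": 0.4}, then mean_stats([es1, es2]) yields {"MAE": 0.3}
# (up to floating-point rounding).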


class EvalStatsPlot(Generic[TEvalStats], ABC):
    @abstractmethod
    def create_figure(self, eval_stats: TEvalStats, subtitle: str) -> Optional[plt.Figure]:
        """
        :param eval_stats: the evaluation stats from which to generate the plot
        :param subtitle: the plot's subtitle
        :return: the figure, or None if this plot is not applicable/cannot be created
        """
        pass

    def is_applicable(self, eval_stats: TEvalStats) -> bool:
        return True
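

# A minimal sketch (commented out) of a concrete plot class; the name ErrorHistogramPlot is a
# hypothetical example, not a plot provided by the library. It uses HistogramPlot (imported above)
# to visualise the prediction errors of a PredictionEvalStats instance.
#
#     class ErrorHistogramPlot(EvalStatsPlot[PredictionEvalStats]):
#         def create_figure(self, eval_stats: PredictionEvalStats, subtitle: str) -> Optional[plt.Figure]:
#             errors = np.array(eval_stats.y_predicted) - np.array(eval_stats.y_true)
#             return HistogramPlot(errors).title(f"Prediction errors\n{subtitle}").fig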