# Coverage for src/sensai/evaluation/eval_stats/eval_stats_base.py: 52% (196 statements)

from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Union, Dict, Sequence, Optional, Tuple, Callable

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ...util.plot import ScatterPlot, HistogramPlot, Plot, HeatMapPlot
from ...util.string import ToStringMixin, dict_string
from ...vector_model import VectorModel

# Note: in the 2020.2 version of PyCharm, passing strings to `bound` is highlighted as an error.
# It does not cause runtime errors, and the static type checker ignores the bound anyway, so it does not matter for now.
# However, this might cause problems with type checking in the future. Therefore, I moved the definition of TEvalStats
# below the definition of EvalStats. Unfortunately, the generics dependency between EvalStats and Metric
# does not allow both TMetric and TEvalStats to be defined properly. For now, we have to leave the bound as a string
# and hope for the best.
TMetric = TypeVar("TMetric", bound="Metric")
TVectorModel = TypeVar("TVectorModel", bound=VectorModel)

PredictionArray = Union[np.ndarray, pd.Series, pd.DataFrame, list]


class EvalStats(Generic[TMetric], ToStringMixin):
    def __init__(self, metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        if len(metrics) == 0:
            raise ValueError("No metrics provided")
        self.metrics = metrics
        # Implementations of EvalStats will typically provide default metrics, therefore we include
        # the possibility of passing additional metrics here
        if additional_metrics is not None:
            self.metrics = self.metrics + additional_metrics
        self.name = None

    def set_name(self, name: str):
        self.name = name

    def add_metric(self, metric: TMetric):
        self.metrics.append(metric)

    def compute_metric_value(self, metric: TMetric) -> float:
        return metric.compute_value_for_eval_stats(self)

    def metrics_dict(self) -> Dict[str, float]:
        """
        Computes all metrics

        :return: a dictionary mapping metric names to values
        """
        d = {}
        for metric in self.metrics:
            d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_all(self) -> Dict[str, float]:
        """Alias for metrics_dict; may be deprecated in the future"""
        return self.metrics_dict()

    def _tostring_object_info(self) -> str:
        return dict_string(self.metrics_dict())
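

# Illustrative usage sketch (not part of this module): EvalStats is typically instantiated via a concrete
# subclass that provides default metrics. `MyEvalStats`, `MyMetric` and `AnotherMetric` are hypothetical names.
#
#     stats = MyEvalStats(..., additional_metrics=[AnotherMetric()])  # subclass defaults plus an extra metric
#     stats.add_metric(MyMetric())                                    # metrics can also be added after construction
#     stats.set_name("fold-0")                                        # used e.g. as the index in to_data_frame()
#     stats.metrics_dict()                                            # e.g. {'R2': 0.87, 'MAE': 0.31, ...}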


TEvalStats = TypeVar("TEvalStats", bound=EvalStats)


class Metric(Generic[TEvalStats], ABC):
    name: str

    def __init__(self, name: str = None, bounds: Optional[Tuple[float, float]] = None):
        """
        :param name: the name of the metric; if None, use the class' ``name`` attribute
        :param bounds: the minimum and maximum values the metric can take on (or None if the bounds are not specified)
        """
        # this raises an AttributeError if a subclass specifies a name neither as a static attribute nor as a parameter
        self.name = name if name is not None else self.__class__.name
        self.bounds = bounds

    @abstractmethod
    def compute_value_for_eval_stats(self, eval_stats: TEvalStats) -> float:
        pass

    def get_paired_metrics(self) -> List[TMetric]:
        """
        Gets a list of metrics that should be considered together with this metric (e.g. for paired visualisations/plots).
        The direction of the pairing should be such that if this metric is "x", the other is "y" for x-y type visualisations.

        :return: a list of metrics
        """
        return []

    def has_finite_bounds(self) -> bool:
        return self.bounds is not None and not any(np.isinf(x) for x in self.bounds)
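

# Sketch of a custom metric (illustrative only; `MeanError` is a hypothetical name, and `RegressionEvalStats`
# is assumed to be a concrete PredictionEvalStats subclass defined elsewhere in the library):
#
#     class MeanError(Metric["RegressionEvalStats"]):
#         name = "MeanError"
#
#         def compute_value_for_eval_stats(self, eval_stats: "RegressionEvalStats") -> float:
#             return float(np.mean(np.array(eval_stats.y_predicted) - np.array(eval_stats.y_true)))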


class EvalStatsCollection(Generic[TEvalStats, TMetric], ABC):
    def __init__(self, eval_stats_list: List[TEvalStats]):
        self.statsList = eval_stats_list
        metric_names_set = None
        metrics_list = []
        for es in eval_stats_list:
            metrics = es.metrics_dict()
            current_metric_names_set = set(metrics.keys())
            if metric_names_set is None:
                metric_names_set = current_metric_names_set
            else:
                if metric_names_set != current_metric_names_set:
                    raise Exception(f"Inconsistent set of metrics in evaluation stats collection: "
                                    f"Got {metric_names_set} for one instance, {current_metric_names_set} for another")
            metrics_list.append(metrics)
        metric_names = sorted(metrics_list[0].keys())
        self._valuesByMetricName = {metric: [d[metric] for d in metrics_list] for metric in metric_names}
        self._metrics: List[TMetric] = eval_stats_list[0].metrics

    def get_values(self, metric_name: str):
        return self._valuesByMetricName[metric_name]

    def get_metric_names(self) -> List[str]:
        return list(self._valuesByMetricName.keys())

    def get_metrics(self) -> List[TMetric]:
        return self._metrics

    def get_metric_by_name(self, name: str) -> Optional[TMetric]:
        for m in self._metrics:
            if m.name == name:
                return m
        return None

    def has_metric(self, metric: Union[Metric, str]) -> bool:
        if not isinstance(metric, str):
            metric = metric.name
        return metric in self._valuesByMetricName

    def agg_metrics_dict(self, agg_fns=(np.mean, np.std)) -> Dict[str, float]:
        agg = {}
        for metric, values in self._valuesByMetricName.items():
            for agg_fn in agg_fns:
                agg[f"{agg_fn.__name__}[{metric}]"] = float(agg_fn(values))
        return agg

    def mean_metrics_dict(self) -> Dict[str, float]:
        metrics = {metric: np.mean(values) for (metric, values) in self._valuesByMetricName.items()}
        return metrics

    def plot_distribution(self, metric_name: str, subtitle: Optional[str] = None, bins=None, kde=False, cdf=False,
            cdf_complementary=False, stat="proportion", **kwargs) -> plt.Figure:
        """
        Plots the distribution of a metric as a histogram

        :param metric_name: name of the metric for which to plot the distribution (histogram) across evaluations
        :param subtitle: the subtitle to add, if any
        :param bins: the histogram bins (number of bins or bin boundaries); if the metric has bounds (0, 1), these are
            used to define the bin range. If None, use 'auto' bins
        :param kde: whether to add a kernel density estimator plot
        :param cdf: whether to add the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if ``cdf`` is True, the complementary cdf instead of the regular cdf
        :param stat: the statistic to compute for each bin ('percent', 'probability'='proportion', 'count', 'frequency' or 'density'),
            i.e. the y-axis value
        :param kwargs: additional parameters to pass to seaborn.histplot (see https://seaborn.pydata.org/generated/seaborn.histplot.html)
        :return: the plot
        """
        # define bins based on metric bounds where available
        x_tick = None
        if bins is None or isinstance(bins, int):
            metric = self.get_metric_by_name(metric_name)
            if metric.bounds == (0, 1):
                x_tick = 0.1
                if bins is None:
                    num_bins = 10 if cdf else 20
                else:
                    num_bins = bins
                bins = np.linspace(0, 1, num_bins + 1)
            else:
                bins = "auto"

        values = self._valuesByMetricName[metric_name]
        title = metric_name
        if subtitle is not None:
            title += "\n" + subtitle
        plot = HistogramPlot(values, bins=bins, stat=stat, kde=kde, cdf=cdf, cdf_complementary=cdf_complementary, **kwargs).title(title)
        if x_tick is not None:
            plot.xtick_major(x_tick)
        return plot.fig

    def _plot_xy(self, metric_name_x, metric_name_y, plot_factory: Callable[[Sequence, Sequence], Plot], adjust_bounds: bool) -> plt.Figure:
        def axlim(bounds):
            min_value, max_value = bounds
            diff = max_value - min_value
            return min_value - 0.05 * diff, max_value + 0.05 * diff

        x = self._valuesByMetricName[metric_name_x]
        y = self._valuesByMetricName[metric_name_y]
        plot = plot_factory(x, y)
        plot.xlabel(metric_name_x)
        plot.ylabel(metric_name_y)
        mx = self.get_metric_by_name(metric_name_x)
        if adjust_bounds and mx.has_finite_bounds():
            plot.xlim(*axlim(mx.bounds))
        my = self.get_metric_by_name(metric_name_y)
        if adjust_bounds and my.has_finite_bounds():
            plot.ylim(*axlim(my.bounds))
        return plot.fig

    def plot_scatter(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, ScatterPlot, adjust_bounds=True)

    def plot_heat_map(self, metric_name_x: str, metric_name_y: str) -> plt.Figure:
        return self._plot_xy(metric_name_x, metric_name_y, HeatMapPlot, adjust_bounds=False)

    def to_data_frame(self) -> pd.DataFrame:
        """
        :return: a DataFrame with the evaluation metrics from all contained EvalStats objects;
            the EvalStats objects' name fields are used as the index if they are set
        """
        data = dict(self._valuesByMetricName)
        index = [stats.name for stats in self.statsList]
        if len([n for n in index if n is not None]) == 0:
            index = None
        return pd.DataFrame(data, index=index)

    def get_global_stats(self) -> TEvalStats:
        """
        Alias for `get_combined_eval_stats`
        """
        return self.get_combined_eval_stats()

    @abstractmethod
    def get_combined_eval_stats(self) -> TEvalStats:
        """
        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        pass

    def __str__(self):
        # keys of agg_metrics_dict() include the aggregation function name (e.g. "mean[R2]"),
        # so we iterate over its items rather than over the plain metric names
        return f"{self.__class__.__name__}[" + \
            ", ".join(f"{key}={value:.4f}" for key, value in self.agg_metrics_dict().items()) + "]"
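

# Illustrative sketch (hypothetical names; concrete collections such as RegressionEvalStatsCollection are
# defined elsewhere in the library): a subclass only needs to implement get_combined_eval_stats.
#
#     class MyEvalStatsCollection(EvalStatsCollection["MyEvalStats", "MyMetric"]):
#         def get_combined_eval_stats(self) -> "MyEvalStats":
#             # a typical implementation pools the underlying data of all contained EvalStats objects
#             ...
#
#     collection = MyEvalStatsCollection([stats_fold0, stats_fold1, stats_fold2])
#     collection.agg_metrics_dict()        # e.g. {'mean[MAE]': ..., 'std[MAE]': ..., ...}
#     collection.plot_distribution("MAE")  # histogram of MAE across the contained EvalStats objects
#     collection.to_data_frame()           # one row per contained EvalStats object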


class PredictionEvalStats(EvalStats[TMetric], ABC):
    """
    Collects data for the evaluation of predicted values (including multi-dimensional predictions)
    and computes corresponding metrics
    """
    def __init__(self, y_predicted: Optional[PredictionArray], y_true: Optional[PredictionArray],
            metrics: List[TMetric], additional_metrics: List[TMetric] = None):
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        :param metrics: list of metrics to be computed on the provided data
        :param additional_metrics: metrics to compute in addition to the ones given via ``metrics``
        """
        self.y_true = []
        self.y_predicted = []
        self.y_true_multidim = None
        self.y_predicted_multidim = None
        if y_predicted is not None:
            self.add_all(y_predicted, y_true)
        super().__init__(metrics, additional_metrics=additional_metrics)

    def add(self, y_predicted, y_true) -> None:
        """
        Adds a single pair of values to the evaluation

        :param y_predicted: the value predicted by the model
        :param y_true: the true value
        """
        self.y_true.append(y_true)
        self.y_predicted.append(y_predicted)

    def add_all(self, y_predicted: PredictionArray, y_true: PredictionArray) -> None:
        """
        :param y_predicted: sequence of predicted values or, in the case of multi-dimensional predictions, either a data frame with
            one column per dimension or a nested sequence of values
        :param y_true: sequence of ground truth labels of the same shape as y_predicted
        """
        def is_sequence(x):
            return isinstance(x, (pd.Series, list, np.ndarray))

        if is_sequence(y_predicted) and is_sequence(y_true):
            a, b = len(y_predicted), len(y_true)
            if a != b:
                raise Exception(f"Lengths differ (predicted {a}, truth {b})")
            if a > 0:
                first_item = y_predicted.iloc[0] if isinstance(y_predicted, pd.Series) else y_predicted[0]
                is_nested_sequence = is_sequence(first_item)
                if is_nested_sequence:
                    for y_true_i, y_predicted_i in zip(y_true, y_predicted):
                        self.add_all(y_predicted=y_predicted_i, y_true=y_true_i)
                else:
                    self.y_true.extend(y_true)
                    self.y_predicted.extend(y_predicted)
        elif isinstance(y_predicted, pd.DataFrame) and isinstance(y_true, pd.DataFrame):
            # keep track of multi-dimensional data (to be used later in get_eval_stats_collection)
            y_predicted_multidim = y_predicted.values
            y_true_multidim = y_true.values
            dim = y_predicted_multidim.shape[1]
            if dim != y_true_multidim.shape[1]:
                raise Exception("Dimension mismatch")
            if self.y_true_multidim is None:
                self.y_predicted_multidim = [[] for _ in range(dim)]
                self.y_true_multidim = [[] for _ in range(dim)]
            if len(self.y_predicted_multidim) != dim:
                raise Exception("Dimension mismatch")
            for i in range(dim):
                self.y_predicted_multidim[i].extend(y_predicted_multidim[:, i])
                self.y_true_multidim[i].extend(y_true_multidim[:, i])
            # convert to flat data for this stats object
            y_predicted = y_predicted_multidim.reshape(-1)
            y_true = y_true_multidim.reshape(-1)
            self.y_true.extend(y_true)
            self.y_predicted.extend(y_predicted)
        else:
            raise Exception(f"Unhandled data types: {type(y_predicted)}, {type(y_true)}")

    def _tostring_object_info(self) -> str:
        return f"{super()._tostring_object_info()}, N={len(self.y_predicted)}"
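

# Illustrative sketch (with `MyPredictionEvalStats` and `MyMetric` as hypothetical names): when DataFrames are passed
# to add_all, per-dimension values are tracked in y_true_multidim/y_predicted_multidim, while the flattened values
# are added to y_true/y_predicted for this stats object.
#
#     y_pred = pd.DataFrame({"out1": [1.0, 2.0], "out2": [3.0, 4.0]})
#     y_true = pd.DataFrame({"out1": [1.1, 1.9], "out2": [2.9, 4.2]})
#     stats = MyPredictionEvalStats(y_pred, y_true, metrics=[MyMetric()])
#     len(stats.y_predicted)           # 4 (flattened across both dimensions)
#     len(stats.y_predicted_multidim)  # 2 (one list of values per dimension)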


def mean_stats(eval_stats_list: Sequence[EvalStats]) -> Dict[str, float]:
    """
    For a list of EvalStats objects, computes the mean values of all metrics and returns them in a dictionary.
    Assumes that all provided EvalStats objects have the same metrics.
    """
    dicts = [s.metrics_dict() for s in eval_stats_list]
    metrics = dicts[0].keys()
    return {m: np.mean([d[m] for d in dicts]) for m in metrics}
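

# Usage sketch: given several EvalStats objects with the same metrics (e.g. one per cross-validation fold),
#
#     mean_stats([stats_fold0, stats_fold1])  # e.g. {'MAE': 0.31, 'R2': 0.87}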


class EvalStatsPlot(Generic[TEvalStats], ABC):
    @abstractmethod
    def create_figure(self, eval_stats: TEvalStats, subtitle: str) -> Optional[plt.Figure]:
        """
        :param eval_stats: the evaluation stats from which to generate the plot
        :param subtitle: the plot's subtitle
        :return: the figure or None if this plot is not applicable/cannot be created
        """
        pass
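

# Sketch of a concrete plot class (illustrative only; `MyScatterPlot` is a hypothetical name, and
# `RegressionEvalStats` is assumed to be a concrete PredictionEvalStats subclass defined elsewhere):
#
#     class MyScatterPlot(EvalStatsPlot["RegressionEvalStats"]):
#         def create_figure(self, eval_stats: "RegressionEvalStats", subtitle: str) -> Optional[plt.Figure]:
#             fig, ax = plt.subplots()
#             ax.scatter(eval_stats.y_true, eval_stats.y_predicted, s=5)
#             ax.set_xlabel("ground truth")
#             ax.set_ylabel("prediction")
#             ax.set_title(f"Predictions vs. ground truth\n{subtitle}")
#             return fig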