Coverage for src/sensai/evaluation/eval_stats/eval_stats_classification.py: 38%
404 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2from abc import ABC, abstractmethod
3from typing import List, Sequence, Optional, Dict, Any, Tuple
5import matplotlib.ticker as plticker
6import numpy as np
7import pandas as pd
8import sklearn
9from matplotlib import pyplot as plt
10from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, precision_recall_curve, \
11 balanced_accuracy_score, f1_score
13from .eval_stats_base import PredictionArray, PredictionEvalStats, EvalStatsCollection, Metric, EvalStatsPlot, TMetric
14from ...util.aggregation import RelativeFrequencyCounter
15from ...util.pickle import getstate
16from ...util.plot import plot_matrix
log = logging.getLogger(__name__)

# Sentinel default for ClassificationEvalStats' binary_positive_label parameter, requesting that the positive
# class label be guessed from the set of class labels
GUESS = ("__guess",)

# Candidate positive class labels for binary classification, in order of preference (used when guessing)
BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES = [
    1,
    True,
    "1",
    "True",
]
class ClassificationMetric(Metric["ClassificationEvalStats"], ABC):
    """
    Base class for metrics which evaluate classification results (predicted labels and, optionally,
    predicted class probabilities)
    """
    requires_probabilities = False

    def __init__(self, name: Optional[str] = None, bounds: Tuple[float, float] = (0, 1), requires_probabilities: Optional[bool] = None):
        """
        :param name: the name of the metric; if None use the class' name attribute
        :param bounds: the minimum and maximum values the metric can take on
        :param requires_probabilities: whether the metric requires predicted class probabilities;
            if None, use the class attribute ``requires_probabilities``
        """
        super().__init__(name=name, bounds=bounds)
        if requires_probabilities is None:
            requires_probabilities = type(self).requires_probabilities
        self.requires_probabilities = requires_probabilities

    def compute_value_for_eval_stats(self, eval_stats: "ClassificationEvalStats"):
        y_true = eval_stats.y_true
        y_pred = eval_stats.y_predicted
        y_proba = eval_stats.y_predicted_class_probabilities
        return self.compute_value(y_true, y_pred, y_proba)

    def compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: Optional[PredictionArray] = None):
        """
        Computes the metric value, first checking that probabilities are present if the metric requires them.

        :raises ValueError: if the metric requires class probabilities but none were given
        """
        probabilities_missing = y_predicted_class_probabilities is None
        if self.requires_probabilities and probabilities_missing:
            raise ValueError(f"{self} requires class probabilities")
        return self._compute_value(y_true, y_predicted, y_predicted_class_probabilities)

    @abstractmethod
    def _compute_value(self, y_true, y_predicted, y_predicted_class_probabilities):
        pass
class ClassificationMetricAccuracy(ClassificationMetric):
    """
    Fraction of data points whose predicted class label equals the true class label
    """
    name = "accuracy"

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        score = accuracy_score(y_true=y_true, y_pred=y_predicted)
        return score
class ClassificationMetricBalancedAccuracy(ClassificationMetric):
    """
    Balanced accuracy as computed by sklearn's ``balanced_accuracy_score``
    """
    name = "balancedAccuracy"

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        score = balanced_accuracy_score(y_true=y_true, y_pred=y_predicted)
        return score
class ClassificationMetricAccuracyWithoutLabels(ClassificationMetric):
    """
    Accuracy score with set of data points limited to the ones where the ground truth label is not one of the given labels
    """
    def __init__(self, *labels: Any, probability_threshold=None, zero_value=0.0):
        """
        :param labels: one or more labels which are not to be considered (all data points where the ground truth is
            one of these labels will be ignored)
        :param probability_threshold: a probability threshold: the probability of the most likely class must be at least this value for a
            data point to be considered in the metric computation (analogous to
            :class:`ClassificationMetricAccuracyMaxProbabilityBeyondThreshold`)
        :param zero_value: the metric value to assume for the case where the condition never applies (no countable instances without
            the given label/beyond the given threshold)
        """
        name_add = "" if probability_threshold is None else f", p_max >= {probability_threshold}"
        label_str = ','.join(map(str, labels))
        name = f"{ClassificationMetricAccuracy.name}Without[{label_str}{name_add}]"
        super().__init__(name, requires_probabilities=probability_threshold is not None)
        self.labels = set(labels)
        self.probability_threshold = probability_threshold
        self.zero_value = zero_value

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        y_true = np.array(y_true)
        y_predicted = np.array(y_predicted)
        threshold = self.probability_threshold

        def is_counted(idx, true_label, predicted_label) -> bool:
            if true_label in self.labels:
                return False
            if threshold is None:
                return True
            # written as "not <" (rather than ">=") so that NaN probabilities are still counted
            return not (y_predicted_class_probabilities[predicted_label].iloc[idx] < threshold)

        kept = [idx for idx, (t, p) in enumerate(zip(y_true, y_predicted)) if is_counted(idx, t, p)]
        if not kept:
            return self.zero_value
        return accuracy_score(y_true=y_true[kept], y_pred=y_predicted[kept])

    def get_paired_metrics(self) -> List[TMetric]:
        if self.probability_threshold is None:
            return []
        return [ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(self.probability_threshold)]
class ClassificationMetricGeometricMeanOfTrueClassProbability(ClassificationMetric):
    """
    Geometric mean of the probabilities predicted for the true class of each data point.
    If the true class is not among the probability columns, the probability is taken to be 0.
    Probabilities are clipped at 1e-3 from below before the logarithm is taken.
    """
    name = "geoMeanTrueClassProb"
    requires_probabilities = True

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # iterate positionally: indexing a pandas Series via y_true[i] would use index labels rather than positions
        y_true = list(y_true)
        columns = y_predicted_class_probabilities.columns
        y_predicted_proba_true_class = np.zeros(len(y_true))
        for i, true_class in enumerate(y_true):
            if true_class in columns:
                y_predicted_proba_true_class[i] = y_predicted_class_probabilities[true_class].iloc[i]
            # otherwise the class is unknown to the model and the probability remains 0
        # the 1e-3 below prevents lp = -inf due to single entries with y_predicted_proba_true_class=0
        lp = np.log(np.maximum(1e-3, y_predicted_proba_true_class))
        return np.exp(lp.sum() / len(lp))
class ClassificationMetricTopNAccuracy(ClassificationMetric):
    """
    Fraction of data points whose true class is among the n classes with the highest predicted probabilities
    """
    requires_probabilities = True

    def __init__(self, n: int):
        """
        :param n: the number of most probable classes among which the true class must be found
        """
        self.n = n
        super().__init__(name=f"top{n}Accuracy")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        labels = y_predicted_class_probabilities.columns
        # iterate positionally: indexing a pandas Series via y_true[i] would use index labels rather than positions
        y_true = list(y_true)
        cnt = 0
        for true_label, row_values in zip(y_true, y_predicted_class_probabilities.values.tolist()):
            pairs = sorted(zip(labels, row_values), key=lambda x: x[1], reverse=True)
            if true_label in [label for label, _ in pairs[:self.n]]:
                cnt += 1
        return cnt / len(y_true)
class ClassificationMetricAccuracyMaxProbabilityBeyondThreshold(ClassificationMetric):
    """
    Accuracy limited to cases where the probability of the most likely class is at least a given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float, zero_value=0.0):
        """
        :param threshold: minimum probability of the most likely class
        :param zero_value: the value of the metric for the case where the probability of the most likely class never reaches the threshold
        """
        self.threshold = threshold
        self.zeroValue = zero_value
        super().__init__(name=f"accuracy[p_max >= {threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        label_to_col_idx = {label: i for i, label in enumerate(y_predicted_class_probabilities.columns)}
        rel_freq = RelativeFrequencyCounter()
        # iterate positionally: indexing a pandas Series via y_true[i] would use index labels rather than positions
        for true_label, probabilities in zip(list(y_true), y_predicted_class_probabilities.values.tolist()):
            class_idx_predicted = np.argmax(probabilities)
            if probabilities[class_idx_predicted] >= self.threshold:
                # -1 if true class is unknown to model (did not appear in training data)
                class_idx_true = label_to_col_idx.get(true_label, -1)
                rel_freq.count(class_idx_predicted == class_idx_true)
        if rel_freq.num_total == 0:
            return self.zeroValue
        return rel_freq.get_relative_frequency()

    def get_paired_metrics(self) -> List[TMetric]:
        return [ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(self.threshold)]
class ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(ClassificationMetric):
    """
    Relative frequency of data points for which the highest predicted class probability reaches a given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float):
        """
        :param threshold: the minimum probability the most likely class must have for a data point to be counted
        """
        self.threshold = threshold
        super().__init__(name=f"relFreq[p_max >= {threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        counter = RelativeFrequencyCounter()
        for row in y_predicted_class_probabilities.values.tolist():
            counter.count(np.max(row) >= self.threshold)
        return counter.get_relative_frequency()
class BinaryClassificationMetric(ClassificationMetric, ABC):
    """
    Base class for metrics of binary classification problems, which refer to a specific positive class label
    """
    def __init__(self, positive_class_label, name: str = None):
        """
        :param positive_class_label: the label of the positive class
        :param name: the metric name; if None, use the class' name attribute. Unless the positive class label is one of
            BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES, the label is appended to the name in brackets.
        """
        base_name = self.__class__.name if name is None else name
        is_standard_label = positive_class_label in BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES
        full_name = base_name if is_standard_label else f"{base_name}[{positive_class_label}]"
        super().__init__(full_name)
        self.positiveClassLabel = positive_class_label
class BinaryClassificationMetricPrecision(BinaryClassificationMetric):
    """
    Precision for the positive class (0 if the positive class is never predicted)
    """
    name = "precision"

    def __init__(self, positive_class_label):
        super().__init__(positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        pos_label = self.positiveClassLabel
        return precision_score(y_true, y_predicted, pos_label=pos_label, zero_division=0)

    def get_paired_metrics(self) -> List[BinaryClassificationMetric]:
        paired = BinaryClassificationMetricRecall(self.positiveClassLabel)
        return [paired]
class BinaryClassificationMetricRecall(BinaryClassificationMetric):
    """
    Recall for the positive class (0 if there are no positive instances)
    """
    name = "recall"

    def __init__(self, positive_class_label):
        super().__init__(positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # zero_division=0, as for BinaryClassificationMetricPrecision: same value as sklearn's default ("warn" acts as 0),
        # but without emitting an UndefinedMetricWarning
        return recall_score(y_true, y_predicted, pos_label=self.positiveClassLabel, zero_division=0)
class BinaryClassificationMetricF1Score(BinaryClassificationMetric):
    """
    F1 score for the positive class (0 if undefined)
    """
    name = "F1"

    def __init__(self, positive_class_label):
        super().__init__(positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # zero_division=0, as for BinaryClassificationMetricPrecision: same value as sklearn's default ("warn" acts as 0),
        # but without emitting an UndefinedMetricWarning
        return f1_score(y_true, y_predicted, pos_label=self.positiveClassLabel, zero_division=0)
class BinaryClassificationMetricRecallForPrecision(BinaryClassificationMetric):
    """
    Computes the maximum recall that can be achieved (by varying the decision threshold) in cases where at least the given precision
    is reached. The given precision may not be achievable at all, in which case the metric value is ``zero_value``.

    Since the computation is based on predicted probabilities (via the eval stats' probability threshold variation data),
    this metric requires class probabilities.
    """
    # the threshold variation data can only be computed from class probabilities; declaring this allows eval stats
    # to skip/reject the metric when no probabilities are available instead of failing during computation
    requires_probabilities = True

    def __init__(self, precision: float, positive_class_label, zero_value=0.0):
        """
        :param precision: the minimum precision value that must be reached
        :param positive_class_label: the positive class label
        :param zero_value: the value to return for the case where the minimum precision is never reached
        """
        self.minPrecision = precision
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"recallForPrecision[{precision}]")

    def compute_value_for_eval_stats(self, eval_stats: "ClassificationEvalStats"):
        # NOTE(review): the variation data is based on eval_stats.binary_positive_label, not on self.positiveClassLabel
        var_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
        recalls = [c.get_recall() for c in var_data.counts if c.get_precision() >= self.minPrecision]
        return max(recalls) if recalls else self.zero_value

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        raise NotImplementedError(f"{self.__class__.__qualname__} only supports compute_value_for_eval_stats")
class BinaryClassificationMetricPrecisionThreshold(BinaryClassificationMetric):
    """
    Precision for the case where a prediction is considered "positive" if the predicted probability of the positive
    class reaches the given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float, positive_class_label: Any, zero_value=0.0):
        """
        :param threshold: the minimum predicted probability of the positive class for the prediction to be considered "positive"
        :param positive_class_label: the positive class label (must be a column of the class probabilities data frame)
        :param zero_value: the value of the metric for the case where a positive class probability beyond the threshold is never predicted
            (denominator = 0)
        """
        self.threshold = threshold
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"precision[{threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        counter = RelativeFrequencyCounter()
        positive_idx = list(y_predicted_class_probabilities.columns).index(self.positiveClassLabel)
        rows = y_predicted_class_probabilities.values.tolist()
        for row, true_label in zip(rows, y_true):
            # only predictions considered positive enter the precision computation
            if row[positive_idx] >= self.threshold:
                counter.count(true_label == self.positiveClassLabel)
        rel_freq = counter.get_relative_frequency()
        return self.zero_value if rel_freq is None else rel_freq

    def get_paired_metrics(self) -> List[BinaryClassificationMetric]:
        paired = BinaryClassificationMetricRecallThreshold(self.threshold, self.positiveClassLabel)
        return [paired]
class BinaryClassificationMetricRecallThreshold(BinaryClassificationMetric):
    """
    Recall for the case where a prediction is considered "positive" if the predicted probability of the positive
    class reaches the given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float, positive_class_label: Any, zero_value=0.0):
        """
        :param threshold: the minimum predicted probability of the positive class for the prediction to be considered "positive"
        :param positive_class_label: the positive class label (must be a column of the class probabilities data frame)
        :param zero_value: the value of the metric for the case where there are no positive instances in the data set (denominator = 0)
        """
        self.threshold = threshold
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"recall[{threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        counter = RelativeFrequencyCounter()
        positive_idx = list(y_predicted_class_probabilities.columns).index(self.positiveClassLabel)
        rows = y_predicted_class_probabilities.values.tolist()
        for row, true_label in zip(rows, y_true):
            # only truly positive instances enter the recall computation
            if self.positiveClassLabel != true_label:
                continue
            counter.count(row[positive_idx] >= self.threshold)
        rel_freq = counter.get_relative_frequency()
        return self.zero_value if rel_freq is None else rel_freq
# metrics which ClassificationEvalStats computes by default for any classification problem
DEFAULT_MULTICLASS_CLASSIFICATION_METRICS = (
    ClassificationMetricAccuracy(),
    ClassificationMetricBalancedAccuracy(),
    ClassificationMetricGeometricMeanOfTrueClassProbability(),
)
def create_default_binary_classification_metrics(positive_class_label: Any) -> List[BinaryClassificationMetric]:
    """
    Creates the metrics which are added by default for binary classification problems.

    :param positive_class_label: the label of the positive class
    :return: a list containing precision, recall and F1 score metrics (in this order) for the given positive class
    """
    metric_types = (BinaryClassificationMetricPrecision, BinaryClassificationMetricRecall, BinaryClassificationMetricF1Score)
    return [metric_type(positive_class_label) for metric_type in metric_types]
class ClassificationEvalStats(PredictionEvalStats["ClassificationMetric"]):
    """
    Evaluation statistics for classification, comprising predicted and true class labels and, optionally,
    the predicted class probabilities
    """
    def __init__(self, y_predicted: Optional[PredictionArray] = None,
                 y_true: Optional[PredictionArray] = None,
                 y_predicted_class_probabilities: Optional[pd.DataFrame] = None,
                 labels: Optional[PredictionArray] = None,
                 metrics: Optional[Sequence["ClassificationMetric"]] = None,
                 additional_metrics: Optional[Sequence["ClassificationMetric"]] = None,
                 binary_positive_label=GUESS):
        """
        :param y_predicted: the predicted class labels
        :param y_true: the true class labels
        :param y_predicted_class_probabilities: a data frame whose columns are the class labels and whose values are probabilities
        :param labels: the list of class labels
        :param metrics: the metrics to compute for evaluation; if None, use default metrics
            (see DEFAULT_MULTICLASS_CLASSIFICATION_METRICS and :func:`create_default_binary_classification_metrics`)
        :param additional_metrics: the metrics to additionally compute
        :param binary_positive_label: the label of the positive class for the case where it is a binary classification, adding further
            binary metrics by default;
            if GUESS (default), check `labels` (if length 2) for occurrence of one of BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES in
            the respective order and use the first one found (if any);
            if None, treat the problem as non-binary, regardless of the labels being used.
        :raises ValueError: if the probability data frame's columns do not correspond to `labels`, if its row count does not
            match `y_true`, or if an additional metric requires probabilities which were not provided
        """
        self.labels = labels
        self.y_predicted_class_probabilities = y_predicted_class_probabilities
        self.is_probabilities_available = y_predicted_class_probabilities is not None
        if self.is_probabilities_available:
            col_set = set(y_predicted_class_probabilities.columns)
            if col_set != set(labels):
                raise ValueError(f"Columns in class probabilities data frame ({y_predicted_class_probabilities.columns}) do not "
                                 f"correspond to labels ({labels})")
            if len(y_predicted_class_probabilities) != len(y_true):
                raise ValueError("Row count in class probabilities data frame does not match ground truth")

        self.binary_positive_label = self._resolve_binary_positive_label(binary_positive_label, labels)
        self.is_binary = self.binary_positive_label is not None

        if metrics is None:
            metrics = list(DEFAULT_MULTICLASS_CLASSIFICATION_METRICS)
            if self.is_binary:
                metrics.extend(create_default_binary_classification_metrics(self.binary_positive_label))

        metrics = list(metrics)
        if additional_metrics is not None:
            for m in additional_metrics:
                if not self.is_probabilities_available and m.requires_probabilities:
                    raise ValueError(f"Additional metric {m} not supported, as class probabilities were not provided")

        super().__init__(y_predicted, y_true, metrics, additional_metrics=additional_metrics)

        # transient members (excluded from pickling, see __getstate__)
        self._binary_classification_probability_threshold_variation_data = None

    @staticmethod
    def _resolve_binary_positive_label(binary_positive_label, labels):
        """
        Determines the effective positive class label (resolving GUESS) and logs warnings for questionable specifications.

        :return: the positive class label, or None if the problem is to be treated as non-binary
        """
        num_labels = len(labels)
        if binary_positive_label == GUESS:
            binary_positive_label = None
            if num_labels == 2:
                # first candidate (in order of preference) which is among the labels, if any
                binary_positive_label = next((c for c in BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES if c in labels), None)
        elif binary_positive_label is not None:
            if num_labels != 2:
                log.warning(f"Passed binaryPositiveLabel for non-binary classification (labels={labels})")
            if binary_positive_label not in labels:
                log.warning(f"The binary positive label {binary_positive_label} does not appear in labels={labels}")
        if num_labels == 2 and binary_positive_label is None:
            log.warning(f"Binary classification (labels={labels}) without specification of positive class label; "
                        f"binary classification metrics will not be considered")
        return binary_positive_label

    def __getstate__(self):
        # the cached threshold variation data is recomputable and must not be pickled; the property name has to match
        # the attribute assigned in __init__ exactly
        return getstate(ClassificationEvalStats, self,
                        transient_properties=["_binary_classification_probability_threshold_variation_data"])

    def get_confusion_matrix(self) -> "ConfusionMatrix":
        return ConfusionMatrix(self.y_true, self.y_predicted)

    def get_binary_classification_probability_threshold_variation_data(self) -> "BinaryClassificationProbabilityThresholdVariationData":
        """
        :return: the (lazily computed and cached) probability threshold variation data for binary classification
        """
        if self._binary_classification_probability_threshold_variation_data is None:
            self._binary_classification_probability_threshold_variation_data = BinaryClassificationProbabilityThresholdVariationData(self)
        return self._binary_classification_probability_threshold_variation_data

    def get_accuracy(self):
        return self.compute_metric_value(ClassificationMetricAccuracy())

    def metrics_dict(self) -> Dict[str, float]:
        """
        :return: a mapping from metric name to value for all metrics which can be computed (metrics requiring probabilities
            are skipped if no probabilities are available)
        """
        d = {}
        for metric in self.metrics:
            if not metric.requires_probabilities or self.is_probabilities_available:
                d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_misclassified_indices(self) -> List[int]:
        return [i for i, (predClass, trueClass) in enumerate(zip(self.y_predicted, self.y_true)) if predClass != trueClass]

    def plot_confusion_matrix(self, normalize=True, title_add: str = None) -> plt.Figure:
        # based on https://scikit-learn.org/0.20/auto_examples/model_selection/plot_confusion_matrix.html
        cm = self.get_confusion_matrix()
        return cm.plot(normalize=normalize, title_add=title_add)

    def plot_precision_recall_curve(self, title_add: str = None) -> plt.Figure:
        """
        :raises ValueError: if no probabilities are available or the problem is not binary
        """
        from sklearn.metrics import PrecisionRecallDisplay  # only supported by newer versions of sklearn
        if not self.is_probabilities_available:
            raise ValueError("Precision-recall curve requires probabilities")
        if not self.is_binary:
            raise ValueError("Precision-recall curve is not applicable to non-binary classification")
        probabilities = self.y_predicted_class_probabilities[self.binary_positive_label]
        # scores passed positionally: the keyword was renamed from probas_pred to y_score in sklearn 1.5
        precision, recall, thresholds = precision_recall_curve(self.y_true, probabilities, pos_label=self.binary_positive_label)
        disp = PrecisionRecallDisplay(precision, recall)
        disp.plot()
        ax: plt.Axes = disp.ax_
        ax.set_xlabel("recall")
        ax.set_ylabel("precision")
        title = "Precision-Recall Curve"
        if title_add is not None:
            title += "\n" + title_add
        ax.set_title(title)
        ax.xaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
        ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
        return disp.figure_
class ClassificationEvalStatsCollection(EvalStatsCollection[ClassificationEvalStats, ClassificationMetric]):
    """
    A collection of classification eval stats (e.g. from several folds), which can be combined into a single object
    """
    def __init__(self, eval_stats_list: List[ClassificationEvalStats]):
        super().__init__(eval_stats_list)
        self.globalStats = None  # lazily created combined stats

    def get_combined_eval_stats(self) -> ClassificationEvalStats:
        """
        Combines the data from all contained EvalStats objects into a single object.
        Note that this is only possible if all EvalStats objects use the same set of class labels.

        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        if self.globalStats is None:
            all_true = np.concatenate([s.y_true for s in self.statsList])
            all_predicted = np.concatenate([s.y_predicted for s in self.statsList])
            # labels, positive class and metrics are taken from the first element
            reference = self.statsList[0]
            if reference.y_predicted_class_probabilities is None:
                combined_probs = None
                combined_labels = reference.labels
            else:
                combined_probs = pd.concat([s.y_predicted_class_probabilities for s in self.statsList])
                combined_labels = list(combined_probs.columns)
            self.globalStats = ClassificationEvalStats(y_predicted=all_predicted, y_true=all_true,
                y_predicted_class_probabilities=combined_probs, labels=combined_labels,
                binary_positive_label=reference.binary_positive_label, metrics=reference.metrics)
        return self.globalStats
class ConfusionMatrix:
    """
    Confusion matrix (rows: true class, columns: predicted class) over the union of labels occurring in the
    ground truth and the predictions
    """
    def __init__(self, y_true: PredictionArray, y_predicted: PredictionArray):
        # import the submodule explicitly instead of relying on `sklearn.utils.multiclass` having been loaded
        # as a side effect of other sklearn imports (a bare `import sklearn` does not guarantee it)
        from sklearn.utils.multiclass import unique_labels
        self.labels = unique_labels(y_true, y_predicted)
        self.confusionMatrix = confusion_matrix(y_true, y_predicted, labels=self.labels)

    def plot(self, normalize: bool = True, title_add: str = None):
        """
        :param normalize: whether to plot normalized values rather than counts
        :param title_add: optional text to add to the title
        :return: the figure created by plot_matrix
        """
        title = 'Normalized Confusion Matrix' if normalize else 'Confusion Matrix (Counts)'
        return plot_matrix(self.confusionMatrix, title, self.labels, self.labels, 'true class', 'predicted class', normalize=normalize,
            title_add=title_add)
class BinaryClassificationCounts:
    """
    Counts of true/false positives/negatives for a binary classification, together with metrics derived from them
    """
    def __init__(self, is_positive_prediction: Sequence[bool], is_positive_ground_truth: Sequence[bool], zero_denominator_metric_value: float = 0.):
        """
        :param is_positive_prediction: the sequence of Booleans indicating whether the model predicted the positive class
        :param is_positive_ground_truth: the sequence of Booleans indicating whether the true class is the positive class
        :param zero_denominator_metric_value: the result to return for metrics such as precision and recall in case the denominator
            is zero (i.e. zero counted cases)
        """
        self.zeroDenominatorMetricValue = zero_denominator_metric_value
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        for pred_positive, gt_positive in zip(is_positive_prediction, is_positive_ground_truth):
            if gt_positive:
                if pred_positive:
                    self.tp += 1
                else:
                    self.fn += 1
            else:
                if pred_positive:
                    self.fp += 1
                else:
                    self.tn += 1

    @classmethod
    def from_probability_threshold(cls, probabilities: Sequence[float], threshold: float, is_positive_ground_truth: Sequence[bool],
            zero_denominator_metric_value: float = 0.) -> "BinaryClassificationCounts":
        """
        :param probabilities: the predicted probabilities of the positive class
        :param threshold: the minimum probability for a prediction to be considered positive
        :param is_positive_ground_truth: the sequence of Booleans indicating whether the true class is the positive class
        :param zero_denominator_metric_value: see constructor
        """
        return cls([p >= threshold for p in probabilities], is_positive_ground_truth,
            zero_denominator_metric_value=zero_denominator_metric_value)

    @classmethod
    def from_eval_stats(cls, eval_stats: "ClassificationEvalStats", threshold=0.5,
            zero_denominator_metric_value: float = 0.) -> "BinaryClassificationCounts":
        """
        :param eval_stats: binary classification eval stats with class probabilities
        :param threshold: the minimum probability of the positive class for a prediction to be considered positive
        :param zero_denominator_metric_value: see constructor
        :raises ValueError: if the problem is not binary or no probabilities are available
        """
        if not eval_stats.is_binary:
            raise ValueError("Probability threshold variation data can only be computed for binary classification problems")
        if eval_stats.y_predicted_class_probabilities is None:
            raise ValueError("No probability data")
        pos_class_label = eval_stats.binary_positive_label
        probs = eval_stats.y_predicted_class_probabilities[pos_class_label]
        is_positive_gt = [gt_label == pos_class_label for gt_label in eval_stats.y_true]
        return cls.from_probability_threshold(probabilities=probs, threshold=threshold, is_positive_ground_truth=is_positive_gt,
            zero_denominator_metric_value=zero_denominator_metric_value)

    def _frac(self, numerator, denominator):
        if denominator == 0:
            return self.zeroDenominatorMetricValue
        return numerator / denominator

    def get_precision(self):
        return self._frac(self.tp, self.tp + self.fp)

    def get_recall(self):
        return self._frac(self.tp, self.tp + self.fn)

    def get_f1(self):
        return self._frac(self.tp, self.tp + 0.5 * (self.fp + self.fn))

    def get_rel_freq_positive(self):
        """
        :return: the relative frequency of positive predictions (zero_denominator_metric_value if nothing was counted)
        """
        positive = self.tp + self.fp
        negative = self.tn + self.fn
        # use _frac to avoid a ZeroDivisionError for empty counts
        return self._frac(positive, positive + negative)
class BinaryClassificationProbabilityThresholdVariationData:
    """
    Binary classification counts for probability thresholds 0, 0.01, ..., 1 (101 thresholds), which enables
    threshold-dependent analyses and plots
    """
    def __init__(self, eval_stats: ClassificationEvalStats):
        """
        :param eval_stats: binary classification eval stats with class probabilities
        :raises ValueError: if the problem is not binary or no probabilities are available
        """
        if not eval_stats.is_binary:
            raise ValueError("Probability threshold variation data can only be computed for binary classification problems")
        if eval_stats.y_predicted_class_probabilities is None:
            raise ValueError("No probability data")
        # threshold-independent data is extracted once rather than once per threshold
        pos_class_label = eval_stats.binary_positive_label
        probs = list(eval_stats.y_predicted_class_probabilities[pos_class_label])
        is_positive_gt = [gt_label == pos_class_label for gt_label in eval_stats.y_true]
        self.thresholds = np.linspace(0, 1, 101)
        self.counts: List[BinaryClassificationCounts] = [
            BinaryClassificationCounts.from_probability_threshold(probs, threshold, is_positive_gt) for threshold in self.thresholds]

    def plot_precision_recall(self, subtitle=None) -> plt.Figure:
        """
        Plots precision, recall, F1 score and the relative frequency of positive predictions against the threshold.
        """
        fig = plt.figure()
        title = "Probability Threshold-Dependent Precision & Recall"
        if subtitle is not None:
            title += "\n" + subtitle
        plt.title(title)
        plt.xlabel("probability threshold")
        precision = [c.get_precision() for c in self.counts]
        recall = [c.get_recall() for c in self.counts]
        f1 = [c.get_f1() for c in self.counts]
        rf_positive = [c.get_rel_freq_positive() for c in self.counts]
        plt.plot(self.thresholds, precision, label="precision")
        plt.plot(self.thresholds, recall, label="recall")
        plt.plot(self.thresholds, f1, label="F1-score")
        plt.plot(self.thresholds, rf_positive, label="rel. freq. positive")
        plt.legend()
        return fig

    def plot_counts(self, subtitle=None):
        """
        Plots the (stacked) TP/TN/FP/FN counts against the threshold.
        """
        fig = plt.figure()
        title = "Probability Threshold-Dependent Counts"
        if subtitle is not None:
            title += "\n" + subtitle
        plt.title(title)
        plt.xlabel("probability threshold")
        plt.stackplot(self.thresholds,
            [c.tp for c in self.counts], [c.tn for c in self.counts], [c.fp for c in self.counts], [c.fn for c in self.counts],
            labels=["true positives", "true negatives", "false positives", "false negatives"],
            colors=["#4fa244", "#79c36f", "#a25344", "#c37d6f"])
        plt.legend()
        return fig
class ClassificationEvalStatsPlot(EvalStatsPlot[ClassificationEvalStats], ABC):
    """
    Base class for plots which are created from classification eval stats
    """
class ClassificationEvalStatsPlotConfusionMatrix(ClassificationEvalStatsPlot):
    """
    Plots the confusion matrix of classification eval stats
    """
    def __init__(self, normalise=True):
        """
        :param normalise: whether to plot normalized values rather than counts
        """
        self.normalise = normalise

    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> plt.Figure:
        figure = eval_stats.plot_confusion_matrix(normalize=self.normalise, title_add=subtitle)
        return figure
class ClassificationEvalStatsPlotPrecisionRecall(ClassificationEvalStatsPlot):
    """
    Plots the precision-recall curve (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if applicable:
            return eval_stats.plot_precision_recall_curve(title_add=subtitle)
        return None
class ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall(ClassificationEvalStatsPlot):
    """
    Plots threshold-dependent precision and recall (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if not applicable:
            return None
        variation_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
        return variation_data.plot_precision_recall(subtitle=subtitle)
class ClassificationEvalStatsPlotProbabilityThresholdCounts(ClassificationEvalStatsPlot):
    """
    Plots threshold-dependent TP/TN/FP/FN counts (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if not applicable:
            return None
        variation_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
        return variation_data.plot_counts(subtitle=subtitle)