Coverage for src/sensai/evaluation/eval_stats/eval_stats_classification.py: 38%

404 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import logging 

2from abc import ABC, abstractmethod 

3from typing import List, Sequence, Optional, Dict, Any, Tuple 

4 

5import matplotlib.ticker as plticker 

6import numpy as np 

7import pandas as pd 

8import sklearn 

9from matplotlib import pyplot as plt 

10from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, precision_recall_curve, \ 

11 balanced_accuracy_score, f1_score 

12 

13from .eval_stats_base import PredictionArray, PredictionEvalStats, EvalStatsCollection, Metric, EvalStatsPlot, TMetric 

14from ...util.aggregation import RelativeFrequencyCounter 

15from ...util.pickle import getstate 

16from ...util.plot import plot_matrix 

17 

log = logging.getLogger(__name__)

# Sentinel default for the binary_positive_label parameter of ClassificationEvalStats: requests that the positive class
# label be guessed from the labels (using BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES)
GUESS = ("__guess",)
# Labels which, in this order of precedence, are considered when guessing the positive class label of a binary classification
# problem; binary metrics whose positive class label is one of these do not append the label to their name
BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES = [1, True, "1", "True"]

23 

24 

class ClassificationMetric(Metric["ClassificationEvalStats"], ABC):
    """
    Base class for metrics which evaluate classification results, computed from the ground truth labels, the predicted labels
    and (for metrics which require them) the predicted class probabilities
    """
    requires_probabilities = False

    def __init__(self, name: Optional[str] = None, bounds: Tuple[float, float] = (0, 1), requires_probabilities: Optional[bool] = None):
        """
        :param name: the name of the metric; if None use the class' name attribute
        :param bounds: the minimum and maximum values the metric can take on
        :param requires_probabilities: whether the metric can only be computed if class probabilities are given;
            if None, use the class' requires_probabilities attribute
        """
        super().__init__(name=name, bounds=bounds)
        if requires_probabilities is None:
            requires_probabilities = type(self).requires_probabilities
        self.requires_probabilities = requires_probabilities

    def compute_value_for_eval_stats(self, eval_stats: "ClassificationEvalStats"):
        """
        Computes the metric value from the data held by the given evaluation stats object.
        """
        y_true = eval_stats.y_true
        y_predicted = eval_stats.y_predicted
        class_probabilities = eval_stats.y_predicted_class_probabilities
        return self.compute_value(y_true, y_predicted, class_probabilities)

    def compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: Optional[PredictionArray] = None):
        """
        :param y_true: the ground truth labels
        :param y_predicted: the predicted labels
        :param y_predicted_class_probabilities: the predicted class probabilities (may be None if the metric does not require them)
        :return: the metric value
        :raises ValueError: if the metric requires class probabilities but none were given
        """
        probabilities_missing = y_predicted_class_probabilities is None
        if probabilities_missing and self.requires_probabilities:
            raise ValueError(f"{self} requires class probabilities")
        return self._compute_value(y_true, y_predicted, y_predicted_class_probabilities)

    @abstractmethod
    def _compute_value(self, y_true, y_predicted, y_predicted_class_probabilities):
        """
        Computes the actual metric value; y_predicted_class_probabilities may be None unless the metric requires probabilities.
        """
        pass

49 

50 

class ClassificationMetricAccuracy(ClassificationMetric):
    """
    Accuracy, i.e. the fraction of data points whose predicted label equals the ground truth label
    """
    name = "accuracy"

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # class probabilities are irrelevant for plain accuracy
        score = accuracy_score(y_true=y_true, y_pred=y_predicted)
        return score

56 

57 

class ClassificationMetricBalancedAccuracy(ClassificationMetric):
    """
    Balanced accuracy as computed by scikit-learn's balanced_accuracy_score (the average of the per-class recall values)
    """
    name = "balancedAccuracy"

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # class probabilities are not used
        score = balanced_accuracy_score(y_true=y_true, y_pred=y_predicted)
        return score

63 

64 

class ClassificationMetricAccuracyWithoutLabels(ClassificationMetric):
    """
    Accuracy computed only on the data points whose ground truth label is not among the given (excluded) labels,
    optionally further restricted to data points whose predicted class has at least a given probability
    """
    def __init__(self, *labels: Any, probability_threshold=None, zero_value=0.0):
        """
        :param labels: one or more labels which are not to be considered (all data points where the ground truth is
            one of these labels will be ignored)
        :param probability_threshold: a probability threshold: the probability of the predicted class must be at least this value for a
            data point to be considered in the metric computation (analogous to
            :class:`ClassificationMetricAccuracyMaxProbabilityBeyondThreshold`); if None, no such restriction applies
        :param zero_value: the metric value to assume for the case where the condition never applies (no countable instances without
            the given label/beyond the given threshold)
        """
        threshold_suffix = "" if probability_threshold is None else f", p_max >= {probability_threshold}"
        excluded_labels_str = ",".join(str(label) for label in labels)
        name = f"{ClassificationMetricAccuracy.name}Without[{excluded_labels_str}{threshold_suffix}]"
        super().__init__(name, requires_probabilities=probability_threshold is not None)
        self.labels = set(labels)
        self.probability_threshold = probability_threshold
        self.zero_value = zero_value

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        y_true = np.array(y_true)
        y_predicted = np.array(y_predicted)
        threshold = self.probability_threshold
        selected = []
        for idx, (true_label, predicted_label) in enumerate(zip(y_true, y_predicted)):
            if true_label in self.labels:
                continue  # ground truth is one of the excluded labels
            if threshold is not None and y_predicted_class_probabilities[predicted_label].iloc[idx] < threshold:
                continue  # the predicted class is not sufficiently probable
            selected.append(idx)
        if not selected:
            return self.zero_value
        return accuracy_score(y_true=y_true[selected], y_pred=y_predicted[selected])

    def get_paired_metrics(self) -> List[TMetric]:
        if self.probability_threshold is None:
            return []
        return [ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(self.probability_threshold)]

108 

109 

class ClassificationMetricGeometricMeanOfTrueClassProbability(ClassificationMetric):
    """
    Geometric mean of the probabilities predicted for the respective true class of each data point.
    A true class which does not appear among the probability columns is assigned probability 0, and all probabilities
    are clipped from below at 1e-3, such that individual zero probabilities do not force the result to 0.
    """
    name = "geoMeanTrueClassProb"
    requires_probabilities = True

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        columns = y_predicted_class_probabilities.columns
        y_predicted_proba_true_class = np.zeros(len(y_true))
        # iterate positionally: indexing y_true[i] would be a label-based lookup (wrong element or KeyError)
        # if y_true is a pandas Series with a non-default index
        for i, true_class in enumerate(y_true):
            if true_class in columns:
                y_predicted_proba_true_class[i] = y_predicted_class_probabilities[true_class].iloc[i]
            # otherwise the true class is unknown to the model and its probability remains 0
        # the 1e-3 below prevents lp = -inf due to single entries with y_predicted_proba_true_class=0
        lp = np.log(np.maximum(1e-3, y_predicted_proba_true_class))
        return np.exp(lp.sum() / len(lp))

125 

126 

class ClassificationMetricTopNAccuracy(ClassificationMetric):
    """
    Fraction of data points whose true class is among the n classes with the highest predicted probabilities
    (among equally probable classes, the one appearing first in the probability columns is ranked higher)
    """
    requires_probabilities = True

    def __init__(self, n: int):
        """
        :param n: the number of most probable classes among which the true class must be found
        """
        self.n = n
        super().__init__(name=f"top{n}Accuracy")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        labels = y_predicted_class_probabilities.columns
        cnt = 0
        # pair ground truth labels and probability rows positionally: indexing y_true[i] would be a label-based lookup
        # if y_true is a pandas Series with a non-default index
        for true_label, row_values in zip(y_true, y_predicted_class_probabilities.values.tolist()):
            # sorted is stable (also with reverse=True), so ties retain the column order
            pairs = sorted(zip(labels, row_values), key=lambda x: x[1], reverse=True)
            if true_label in (label for label, _ in pairs[:self.n]):
                cnt += 1
        return cnt / len(y_true)

142 

143 

class ClassificationMetricAccuracyMaxProbabilityBeyondThreshold(ClassificationMetric):
    """
    Accuracy limited to cases where the probability of the most likely class is at least a given threshold.
    The predicted class is taken to be the most probable class according to the class probabilities.
    """
    requires_probabilities = True

    def __init__(self, threshold: float, zero_value=0.0):
        """
        :param threshold: minimum probability of the most likely class
        :param zero_value: the value of the metric for the case where the probability of the most likely class never reaches the threshold
        """
        self.threshold = threshold
        self.zeroValue = zero_value  # attribute name retained for backward compatibility
        super().__init__(name=f"accuracy[p_max >= {threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        labels = y_predicted_class_probabilities.columns
        label_to_col_idx = {label: idx for idx, label in enumerate(labels)}
        rel_freq = RelativeFrequencyCounter()
        # pair ground truth labels and probability rows positionally: indexing y_true[i] would be a label-based lookup
        # if y_true is a pandas Series with a non-default index
        for true_label, probabilities in zip(y_true, y_predicted_class_probabilities.values.tolist()):
            class_idx_predicted = np.argmax(probabilities)
            prob_predicted = probabilities[class_idx_predicted]
            if prob_predicted >= self.threshold:
                # -1 if true class is unknown to model (did not appear in training data)
                class_idx_true = label_to_col_idx.get(true_label, -1)
                rel_freq.count(class_idx_predicted == class_idx_true)
        if rel_freq.num_total == 0:
            return self.zeroValue
        return rel_freq.get_relative_frequency()

    def get_paired_metrics(self) -> List[TMetric]:
        return [ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(self.threshold)]

176 

177 

class ClassificationMetricRelFreqMaxProbabilityBeyondThreshold(ClassificationMetric):
    """
    Relative frequency of the data points for which the highest predicted class probability reaches a given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float):
        """
        :param threshold: the minimum probability of the most likely class for a data point to be counted
        """
        self.threshold = threshold
        super().__init__(name=f"relFreq[p_max >= {threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        counter = RelativeFrequencyCounter()
        for row in y_predicted_class_probabilities.values.tolist():
            counter.count(np.max(row) >= self.threshold)
        return counter.get_relative_frequency()

197 

198 

class BinaryClassificationMetric(ClassificationMetric, ABC):
    """
    Base class for metrics of binary classification problems, which are computed with respect to a positive class label
    """
    def __init__(self, positive_class_label, name: str = None):
        """
        :param positive_class_label: the label of the positive class
        :param name: the name of the metric; if None, use the class' name attribute. Unless the positive class label is
            one of BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES, the label is appended to the name in brackets.
        """
        base_name = self.__class__.name if name is None else name
        has_standard_label = positive_class_label in BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES
        full_name = base_name if has_standard_label else f"{base_name}[{positive_class_label}]"
        super().__init__(full_name)
        self.positiveClassLabel = positive_class_label

206 

207 

class BinaryClassificationMetricPrecision(BinaryClassificationMetric):
    """
    Precision with respect to the positive class (0 if the positive class is never predicted)
    """
    name = "precision"

    def __init__(self, positive_class_label):
        """
        :param positive_class_label: the label of the positive class
        """
        super().__init__(positive_class_label=positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        positive_label = self.positiveClassLabel
        return precision_score(y_true, y_predicted, pos_label=positive_label, zero_division=0)

    def get_paired_metrics(self) -> List[BinaryClassificationMetric]:
        recall_metric = BinaryClassificationMetricRecall(self.positiveClassLabel)
        return [recall_metric]

219 

220 

class BinaryClassificationMetricRecall(BinaryClassificationMetric):
    """
    Recall with respect to the positive class (0 if there are no positive instances in the ground truth)
    """
    name = "recall"

    def __init__(self, positive_class_label):
        """
        :param positive_class_label: the label of the positive class
        """
        super().__init__(positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # zero_division=0 for consistency with the precision metric: same value as sklearn's default ("warn"),
        # but without emitting an UndefinedMetricWarning when there are no positive instances
        return recall_score(y_true, y_predicted, pos_label=self.positiveClassLabel, zero_division=0)

229 

230 

class BinaryClassificationMetricF1Score(BinaryClassificationMetric):
    """
    F1 score with respect to the positive class (0 if there are neither positive predictions nor positive instances)
    """
    name = "F1"

    def __init__(self, positive_class_label):
        """
        :param positive_class_label: the label of the positive class
        """
        super().__init__(positive_class_label)

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        # zero_division=0 for consistency with the precision metric: same value as sklearn's default ("warn"),
        # but without emitting an UndefinedMetricWarning
        return f1_score(y_true, y_predicted, pos_label=self.positiveClassLabel, zero_division=0)

239 

240 

class BinaryClassificationMetricRecallForPrecision(BinaryClassificationMetric):
    """
    Computes the maximum recall that can be achieved (by varying the decision threshold on the positive class probability
    over the grid used by BinaryClassificationProbabilityThresholdVariationData) in cases where at least the given precision
    is reached. The given precision may not be achievable at all, in which case the metric value is ``zero_value``.
    """
    # the computation is based on threshold variation data, which cannot be computed without class probabilities
    requires_probabilities = True

    def __init__(self, precision: float, positive_class_label, zero_value=0.0):
        """
        :param precision: the minimum precision value that must be reached
        :param positive_class_label: the positive class label
        :param zero_value: the value to return for the case where the minimum precision is never reached
        """
        self.minPrecision = precision
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"recallForPrecision[{precision}]")

    def compute_value_for_eval_stats(self, eval_stats: "ClassificationEvalStats"):
        var_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
        admissible_recalls = [c.get_recall() for c in var_data.counts if c.get_precision() >= self.minPrecision]
        return max(admissible_recalls) if admissible_recalls else self.zero_value

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        raise NotImplementedError(f"{self.__class__.__qualname__} only supports compute_value_for_eval_stats")

269 

270 

class BinaryClassificationMetricPrecisionThreshold(BinaryClassificationMetric):
    """
    Precision for the case where a prediction counts as "positive" if the predicted probability of the positive class
    reaches the given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float, positive_class_label: Any, zero_value=0.0):
        """
        :param threshold: the minimum predicted probability of the positive class for the prediction to be considered "positive"
        :param positive_class_label: the label of the positive class
        :param zero_value: the value of the metric for the case where a positive class probability beyond the threshold is never predicted
            (denominator = 0)
        """
        self.threshold = threshold
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"precision[{threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        positive_label = self.positiveClassLabel
        counter = RelativeFrequencyCounter()
        positive_col = list(y_predicted_class_probabilities.columns).index(positive_label)
        rows = y_predicted_class_probabilities.values.tolist()
        for row, true_label in zip(rows, y_true):
            # only predictions deemed positive contribute to precision
            if row[positive_col] >= self.threshold:
                counter.count(true_label == positive_label)
        result = counter.get_relative_frequency()
        return self.zero_value if result is None else result

    def get_paired_metrics(self) -> List[BinaryClassificationMetric]:
        recall_metric = BinaryClassificationMetricRecallThreshold(self.threshold, self.positiveClassLabel)
        return [recall_metric]

300 

301 

class BinaryClassificationMetricRecallThreshold(BinaryClassificationMetric):
    """
    Recall for the case where a prediction counts as "positive" if the predicted probability of the positive class
    reaches the given threshold
    """
    requires_probabilities = True

    def __init__(self, threshold: float, positive_class_label: Any, zero_value=0.0):
        """
        :param threshold: the minimum predicted probability of the positive class for the prediction to be considered "positive"
        :param positive_class_label: the label of the positive class
        :param zero_value: the value of the metric for the case where there are no positive instances in the data set (denominator = 0)
        """
        self.threshold = threshold
        self.zero_value = zero_value
        super().__init__(positive_class_label, name=f"recall[{threshold}]")

    def _compute_value(self, y_true: PredictionArray, y_predicted: PredictionArray, y_predicted_class_probabilities: PredictionArray):
        positive_label = self.positiveClassLabel
        counter = RelativeFrequencyCounter()
        positive_col = list(y_predicted_class_probabilities.columns).index(positive_label)
        rows = y_predicted_class_probabilities.values.tolist()
        for row, true_label in zip(rows, y_true):
            # only actual positive instances contribute to recall
            if positive_label == true_label:
                counter.count(row[positive_col] >= self.threshold)
        result = counter.get_relative_frequency()
        return self.zero_value if result is None else result

327 

328 

# Metrics used by ClassificationEvalStats when no metrics are specified (for binary problems, further binary metrics are added).
# The geometric mean metric requires class probabilities and is skipped by ClassificationEvalStats.metrics_dict if none are available.
# NOTE: these metric instances are module-level objects shared by all ClassificationEvalStats instances using the defaults.
DEFAULT_MULTICLASS_CLASSIFICATION_METRICS = (ClassificationMetricAccuracy(), ClassificationMetricBalancedAccuracy(),
    ClassificationMetricGeometricMeanOfTrueClassProbability())

331 

332 

def create_default_binary_classification_metrics(positive_class_label: Any) -> List[BinaryClassificationMetric]:
    """
    Creates the metrics which ClassificationEvalStats adds by default for binary classification problems.

    :param positive_class_label: the label of the positive class
    :return: a new list containing precision, recall and F1 score metrics for the given positive class
    """
    metric_factories = (BinaryClassificationMetricPrecision, BinaryClassificationMetricRecall, BinaryClassificationMetricF1Score)
    return [factory(positive_class_label) for factory in metric_factories]

336 

337 

class ClassificationEvalStats(PredictionEvalStats["ClassificationMetric"]):
    """
    Evaluation data of a classification model (predicted labels, ground truth labels and, optionally, predicted
    class probabilities), from which classification metrics and plots can be computed
    """
    def __init__(self, y_predicted: Optional[PredictionArray] = None,
            y_true: Optional[PredictionArray] = None,
            y_predicted_class_probabilities: Optional[pd.DataFrame] = None,
            labels: Optional[PredictionArray] = None,
            metrics: Optional[Sequence["ClassificationMetric"]] = None,
            additional_metrics: Optional[Sequence["ClassificationMetric"]] = None,
            binary_positive_label=GUESS):
        """
        :param y_predicted: the predicted class labels
        :param y_true: the true class labels
        :param y_predicted_class_probabilities: a data frame whose columns are the class labels and whose values are probabilities
        :param labels: the list of class labels
        :param metrics: the metrics to compute for evaluation; if None, use default metrics
            (see DEFAULT_MULTICLASS_CLASSIFICATION_METRICS and :func:`create_default_binary_classification_metrics`)
        :param additional_metrics: the metrics to additionally compute
        :param binary_positive_label: the label of the positive class for the case where it is a binary classification, adding further
            binary metrics by default;
            if GUESS (default), check `labels` (if length 2) for occurrence of one of BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES in
            the respective order and use the first one found (if any);
            if None, treat the problem as non-binary, regardless of the labels being used.
        :raises ValueError: if the columns of the class probabilities do not correspond to the labels, if the row count of the
            class probabilities does not match the ground truth, or if an additional metric requires class probabilities
            which were not provided
        """
        self.labels = labels
        self.y_predicted_class_probabilities = y_predicted_class_probabilities
        self.is_probabilities_available = y_predicted_class_probabilities is not None
        if self.is_probabilities_available:
            col_set = set(y_predicted_class_probabilities.columns)
            if col_set != set(labels):
                raise ValueError(f"Columns in class probabilities data frame ({y_predicted_class_probabilities.columns}) do not "
                    f"correspond to labels ({labels})")
            if len(y_predicted_class_probabilities) != len(y_true):
                raise ValueError("Row count in class probabilities data frame does not match ground truth")

        num_labels = len(labels)
        if binary_positive_label == GUESS:
            # use the first candidate (in order of precedence) occurring among exactly two labels; otherwise treat as non-binary
            binary_positive_label = None
            if num_labels == 2:
                for c in BINARY_CLASSIFICATION_POSITIVE_LABEL_CANDIDATES:
                    if c in labels:
                        binary_positive_label = c
                        break
        elif binary_positive_label is not None:
            if num_labels != 2:
                log.warning(f"Passed binary_positive_label for non-binary classification (labels={self.labels})")
            if binary_positive_label not in self.labels:
                log.warning(f"The binary positive label {binary_positive_label} does not appear in labels={labels}")
        if num_labels == 2 and binary_positive_label is None:
            log.warning(f"Binary classification (labels={labels}) without specification of positive class label; "
                "binary classification metrics will not be considered")
        self.binary_positive_label = binary_positive_label
        self.is_binary = binary_positive_label is not None

        if metrics is None:
            metrics = list(DEFAULT_MULTICLASS_CLASSIFICATION_METRICS)
            if self.is_binary:
                metrics.extend(create_default_binary_classification_metrics(self.binary_positive_label))

        metrics = list(metrics)
        if additional_metrics is not None:
            for m in additional_metrics:
                if not self.is_probabilities_available and m.requires_probabilities:
                    raise ValueError(f"Additional metric {m} not supported, as class probabilities were not provided")

        super().__init__(y_predicted, y_true, metrics, additional_metrics=additional_metrics)

        # transient members
        self._binary_classification_probability_threshold_variation_data = None

    def __getstate__(self):
        # The cached threshold variation data is marked transient so that it is not pickled; the name must match the
        # attribute assigned in __init__ (it previously used a stale camelCase name, so the cache was pickled after all).
        # NOTE(review): assumes getstate stores transient properties as None, so the cache is recomputed lazily on demand.
        return getstate(ClassificationEvalStats, self,
            transient_properties=["_binary_classification_probability_threshold_variation_data"])

    def get_confusion_matrix(self) -> "ConfusionMatrix":
        """
        :return: the confusion matrix of true vs. predicted class labels
        """
        return ConfusionMatrix(self.y_true, self.y_predicted)

    def get_binary_classification_probability_threshold_variation_data(self) -> "BinaryClassificationProbabilityThresholdVariationData":
        """
        :return: the binary classification counts for varying thresholds on the positive class probability (computed on first access
            and cached)
        """
        if self._binary_classification_probability_threshold_variation_data is None:
            self._binary_classification_probability_threshold_variation_data = BinaryClassificationProbabilityThresholdVariationData(self)
        return self._binary_classification_probability_threshold_variation_data

    def get_accuracy(self):
        return self.compute_metric_value(ClassificationMetricAccuracy())

    def metrics_dict(self) -> Dict[str, float]:
        """
        :return: a mapping from metric name to metric value, omitting metrics which require class probabilities if none are available
        """
        d = {}
        for metric in self.metrics:
            if not metric.requires_probabilities or self.is_probabilities_available:
                d[metric.name] = self.compute_metric_value(metric)
        return d

    def get_misclassified_indices(self) -> List[int]:
        """
        :return: the (positional) indices of the data points whose predicted label differs from the ground truth label
        """
        return [i for i, (pred_class, true_class) in enumerate(zip(self.y_predicted, self.y_true)) if pred_class != true_class]

    def plot_confusion_matrix(self, normalize=True, title_add: str = None) -> plt.Figure:
        # based on https://scikit-learn.org/0.20/auto_examples/model_selection/plot_confusion_matrix.html
        cm = self.get_confusion_matrix()
        return cm.plot(normalize=normalize, title_add=title_add)

    def plot_precision_recall_curve(self, title_add: str = None) -> plt.Figure:
        """
        :param title_add: an optional addition to the plot title (added in a new line)
        :return: the figure showing the precision-recall curve with respect to the positive class
        :raises Exception: if no class probabilities are available or the problem is not binary
        """
        from sklearn.metrics import PrecisionRecallDisplay  # only supported by newer versions of sklearn
        if not self.is_probabilities_available:
            raise Exception("Precision-recall curve requires probabilities")
        if not self.is_binary:
            raise Exception("Precision-recall curve is not applicable to non-binary classification")
        probabilities = self.y_predicted_class_probabilities[self.binary_positive_label]
        # pass the scores positionally: the keyword was renamed from probas_pred to y_score in scikit-learn 1.5
        # (with probas_pred subsequently removed), whereas the positional form works across versions
        precision, recall, _thresholds = precision_recall_curve(self.y_true, probabilities, pos_label=self.binary_positive_label)
        disp = PrecisionRecallDisplay(precision, recall)
        disp.plot()
        ax: plt.Axes = disp.ax_
        ax.set_xlabel("recall")
        ax.set_ylabel("precision")
        title = "Precision-Recall Curve"
        if title_add is not None:
            title += "\n" + title_add
        ax.set_title(title)
        ax.xaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
        ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
        return disp.figure_

459 

460 

class ClassificationEvalStatsCollection(EvalStatsCollection[ClassificationEvalStats, ClassificationMetric]):
    """
    A collection of classification evaluation stats (e.g. from several folds), which can be combined into a single object
    """
    def __init__(self, eval_stats_list: List[ClassificationEvalStats]):
        super().__init__(eval_stats_list)
        self.globalStats = None  # combined eval stats, computed on first request

    def get_combined_eval_stats(self) -> ClassificationEvalStats:
        """
        Combines the data from all contained EvalStats objects into a single object.
        Note that this is only possible if all EvalStats objects use the same set of class labels.
        The labels, binary positive label and metrics of the first contained object are used for the combined object.

        :return: an EvalStats object that combines the data from all contained EvalStats objects
        """
        if self.globalStats is not None:
            return self.globalStats
        stats_list = self.statsList
        y_true = np.concatenate([s.y_true for s in stats_list])
        y_predicted = np.concatenate([s.y_predicted for s in stats_list])
        first = stats_list[0]
        if first.y_predicted_class_probabilities is None:
            y_probs, labels = None, first.labels
        else:
            y_probs = pd.concat([s.y_predicted_class_probabilities for s in stats_list])
            labels = list(y_probs.columns)
        self.globalStats = ClassificationEvalStats(y_predicted=y_predicted, y_true=y_true, y_predicted_class_probabilities=y_probs,
            labels=labels, binary_positive_label=first.binary_positive_label, metrics=first.metrics)
        return self.globalStats

486 

487 

class ConfusionMatrix:
    """
    Confusion matrix of true (rows) vs. predicted (columns) class labels over the union of all labels that occur in either
    """
    def __init__(self, y_true: PredictionArray, y_predicted: PredictionArray):
        """
        :param y_true: the ground truth labels
        :param y_predicted: the predicted labels
        """
        # import the submodule explicitly: a plain `import sklearn` does not guarantee that `sklearn.utils.multiclass`
        # is available as an attribute (it previously only worked as a side effect of other sklearn imports)
        from sklearn.utils.multiclass import unique_labels
        self.labels = unique_labels(y_true, y_predicted)
        self.confusionMatrix = confusion_matrix(y_true, y_predicted, labels=self.labels)

    def plot(self, normalize: bool = True, title_add: str = None):
        """
        :param normalize: whether to plot the normalized matrix rather than raw counts (normalization is performed by plot_matrix)
        :param title_add: an optional addition to the plot title
        :return: the figure created by plot_matrix
        """
        title = 'Normalized Confusion Matrix' if normalize else 'Confusion Matrix (Counts)'
        return plot_matrix(self.confusionMatrix, title, self.labels, self.labels, 'true class', 'predicted class', normalize=normalize,
            title_add=title_add)
496 title_add=title_add) 

497 

498 

class BinaryClassificationCounts:
    """
    Counts of true/false positives and negatives of a binary classification, along with metrics derived from them
    """
    def __init__(self, is_positive_prediction: Sequence[bool], is_positive_ground_truth: Sequence[bool], zero_denominator_metric_value: float = 0.):
        """
        :param is_positive_prediction: the sequence of Booleans indicating whether the model predicted the positive class
        :param is_positive_ground_truth: the sequence of Booleans indicating whether the true class is the positive class
        :param zero_denominator_metric_value: the result to return for metrics such as precision and recall in case the denominator
            is zero (i.e. zero counted cases)
        """
        self.zeroDenominatorMetricValue = zero_denominator_metric_value
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        for pred_positive, gt_positive in zip(is_positive_prediction, is_positive_ground_truth):
            if gt_positive:
                if pred_positive:
                    self.tp += 1
                else:
                    self.fn += 1
            elif pred_positive:
                self.fp += 1
            else:
                self.tn += 1

    @classmethod
    def from_probability_threshold(cls, probabilities: Sequence[float], threshold: float, is_positive_ground_truth: Sequence[bool],
            zero_denominator_metric_value: float = 0.) -> "BinaryClassificationCounts":
        """
        :param probabilities: the predicted probabilities of the positive class
        :param threshold: the minimum probability for a prediction to be considered positive
        :param is_positive_ground_truth: the sequence of Booleans indicating whether the true class is the positive class
        :param zero_denominator_metric_value: the metric result in case of a zero denominator
        :return: the counts for predictions obtained by applying the threshold
        """
        return cls([p >= threshold for p in probabilities], is_positive_ground_truth,
            zero_denominator_metric_value=zero_denominator_metric_value)

    @classmethod
    def from_eval_stats(cls, eval_stats: "ClassificationEvalStats", threshold=0.5,
            zero_denominator_metric_value: float = 0.) -> "BinaryClassificationCounts":
        """
        :param eval_stats: the evaluation data of a binary classification problem including class probabilities
        :param threshold: the minimum probability of the positive class for a prediction to be considered positive
        :param zero_denominator_metric_value: the metric result in case of a zero denominator
        :return: the counts for predictions obtained by applying the threshold to the positive class probabilities
        :raises ValueError: if the problem is not binary or no class probabilities are available
        """
        if not eval_stats.is_binary:
            raise ValueError("Probability threshold variation data can only be computed for binary classification problems")
        if eval_stats.y_predicted_class_probabilities is None:
            raise ValueError("No probability data")
        pos_class_label = eval_stats.binary_positive_label
        probs = eval_stats.y_predicted_class_probabilities[pos_class_label]
        is_positive_gt = [gt_label == pos_class_label for gt_label in eval_stats.y_true]
        return cls.from_probability_threshold(probabilities=probs, threshold=threshold, is_positive_ground_truth=is_positive_gt,
            zero_denominator_metric_value=zero_denominator_metric_value)

    def _frac(self, numerator, denominator):
        # guards all derived metrics against division by zero
        if denominator == 0:
            return self.zeroDenominatorMetricValue
        return numerator / denominator

    def get_precision(self):
        return self._frac(self.tp, self.tp + self.fp)

    def get_recall(self):
        return self._frac(self.tp, self.tp + self.fn)

    def get_f1(self):
        return self._frac(self.tp, self.tp + 0.5 * (self.fp + self.fn))

    def get_rel_freq_positive(self):
        """
        :return: the relative frequency of positive predictions (the zero-denominator value if there are no data points at all;
            this previously raised ZeroDivisionError)
        """
        positive = self.tp + self.fp
        negative = self.tn + self.fn
        return self._frac(positive, positive + negative)

558 

559 

class BinaryClassificationProbabilityThresholdVariationData:
    """
    Binary classification counts for 101 equidistant probability thresholds from 0 to 1 (step 0.01), applied to the
    predicted probability of the positive class
    """
    def __init__(self, eval_stats: ClassificationEvalStats):
        """
        :param eval_stats: the evaluation data of a binary classification problem including class probabilities
        :raises ValueError: if the problem is not binary or no class probabilities are available
        """
        self.thresholds = np.linspace(0, 1, 101)
        self.counts: List[BinaryClassificationCounts] = [BinaryClassificationCounts.from_eval_stats(eval_stats, threshold=t)
            for t in self.thresholds]

    @staticmethod
    def _new_figure(title: str, subtitle: Optional[str]) -> plt.Figure:
        # creates the figure and sets up the title (plus optional subtitle line) and the threshold axis label
        fig = plt.figure()
        full_title = title if subtitle is None else title + "\n" + subtitle
        plt.title(full_title)
        plt.xlabel("probability threshold")
        return fig

    def plot_precision_recall(self, subtitle=None) -> plt.Figure:
        """
        :param subtitle: an optional subtitle
        :return: a figure showing precision, recall, F1 score and the relative frequency of positive predictions as functions
            of the threshold
        """
        fig = self._new_figure("Probability Threshold-Dependent Precision & Recall", subtitle)
        curves = {
            "precision": [c.get_precision() for c in self.counts],
            "recall": [c.get_recall() for c in self.counts],
            "F1-score": [c.get_f1() for c in self.counts],
            "rel. freq. positive": [c.get_rel_freq_positive() for c in self.counts],
        }
        for curve_label, values in curves.items():
            plt.plot(self.thresholds, values, label=curve_label)
        plt.legend()
        return fig

    def plot_counts(self, subtitle=None):
        """
        :param subtitle: an optional subtitle
        :return: a figure showing the stacked counts of true/false positives/negatives as functions of the threshold
        """
        fig = self._new_figure("Probability Threshold-Dependent Counts", subtitle)
        stacks = [
            [c.tp for c in self.counts],
            [c.tn for c in self.counts],
            [c.fp for c in self.counts],
            [c.fn for c in self.counts],
        ]
        plt.stackplot(self.thresholds, *stacks,
            labels=["true positives", "true negatives", "false positives", "false negatives"],
            colors=["#4fa244", "#79c36f", "#a25344", "#c37d6f"])
        plt.legend()
        return fig

597 return fig 

598 

599 

class ClassificationEvalStatsPlot(EvalStatsPlot[ClassificationEvalStats], ABC):
    """
    Abstract base class for plots created from ClassificationEvalStats
    """

602 

603 

class ClassificationEvalStatsPlotConfusionMatrix(ClassificationEvalStatsPlot):
    """
    Plots the confusion matrix of the given evaluation data
    """
    def __init__(self, normalise=True):
        """
        :param normalise: whether to plot the normalized confusion matrix rather than raw counts
        """
        self.normalise = normalise

    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> plt.Figure:
        return eval_stats.plot_confusion_matrix(title_add=subtitle, normalize=self.normalise)

610 

611 

class ClassificationEvalStatsPlotPrecisionRecall(ClassificationEvalStatsPlot):
    """
    Plots the precision-recall curve (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if applicable:
            return eval_stats.plot_precision_recall_curve(title_add=subtitle)
        return None

617 

618 

class ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall(ClassificationEvalStatsPlot):
    """
    Plots threshold-dependent precision, recall, F1 score and relative frequency of positive predictions
    (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if applicable:
            variation_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
            return variation_data.plot_precision_recall(subtitle=subtitle)
        return None

624 

625 

class ClassificationEvalStatsPlotProbabilityThresholdCounts(ClassificationEvalStatsPlot):
    """
    Plots threshold-dependent counts of true/false positives/negatives
    (only applicable to binary classification with class probabilities)
    """
    def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
        applicable = eval_stats.is_binary and eval_stats.is_probabilities_available
        if applicable:
            variation_data = eval_stats.get_binary_classification_probability_threshold_variation_data()
            return variation_data.plot_counts(subtitle=subtitle)
        return None