Coverage for src/sensai/evaluation/eval_util.py: 21% (535 statements)

coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1""" 

2This module contains methods and classes that facilitate evaluation of different types of models. The suggested 

3workflow for evaluation is to use these higher-level functionalities instead of instantiating 

4the evaluation classes directly. 

5""" 

6import functools 

7import logging 

8from abc import ABC, abstractmethod 

9from collections import defaultdict 

10from dataclasses import dataclass 

11from typing import Dict, Any, Union, Generic, TypeVar, Optional, Sequence, Callable, Set, Iterable, List, Iterator, Tuple 

12 

13import matplotlib.figure 

14import matplotlib.pyplot as plt 

15import numpy as np 

16import pandas as pd 

17import seaborn as sns 

18 

19from .crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData, \ 

20 VectorClassificationModelCrossValidationData, \ 

21 VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator, VectorModelCrossValidator, VectorModelCrossValidatorParams 

22from .eval_stats import RegressionEvalStatsCollection, ClassificationEvalStatsCollection, RegressionEvalStatsPlotErrorDistribution, \ 

23 RegressionEvalStatsPlotHeatmapGroundTruthPredictions, RegressionEvalStatsPlotScatterGroundTruthPredictions, \ 

24 ClassificationEvalStatsPlotConfusionMatrix, ClassificationEvalStatsPlotPrecisionRecall, RegressionEvalStatsPlot, \ 

25 ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall, ClassificationEvalStatsPlotProbabilityThresholdCounts, \ 

26 Metric 

27from .eval_stats.eval_stats_base import EvalStats, EvalStatsCollection, EvalStatsPlot 

28from .eval_stats.eval_stats_classification import ClassificationEvalStats 

29from .eval_stats.eval_stats_regression import RegressionEvalStats 

30from .evaluator import VectorModelEvaluator, VectorModelEvaluationData, VectorRegressionModelEvaluator, \ 

31 VectorRegressionModelEvaluationData, VectorClassificationModelEvaluator, VectorClassificationModelEvaluationData, \ 

32 RegressionEvaluatorParams, ClassificationEvaluatorParams 

33from ..data import InputOutputData 

34from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plot_feature_importance, FeatureImportance 

35from ..tracking import TrackedExperiment 

36from ..tracking.tracking_base import TrackingContext 

37from ..util.deprecation import deprecated 

38from ..util.io import ResultWriter 

39from ..util.string import pretty_string_repr 

40from ..vector_model import VectorClassificationModel, VectorRegressionModel, VectorModel, VectorModelBase 

41 

42log = logging.getLogger(__name__) 

43 

44TModel = TypeVar("TModel", bound=VectorModel) 

45TEvalStats = TypeVar("TEvalStats", bound=EvalStats) 

46TEvalStatsPlot = TypeVar("TEvalStatsPlot", bound=EvalStatsPlot) 

47TEvalStatsCollection = TypeVar("TEvalStatsCollection", bound=EvalStatsCollection) 

48TEvaluator = TypeVar("TEvaluator", bound=VectorModelEvaluator) 

49TCrossValidator = TypeVar("TCrossValidator", bound=VectorModelCrossValidator) 

50TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData) 

51TCrossValData = TypeVar("TCrossValData", bound=VectorModelCrossValidationData) 

52 

53 

54def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool]) -> bool: 

55 if (model is None and is_regression is None) or (model is not None and is_regression is not None):

56 raise ValueError("Exactly one of the two parameters has to be passed: model or is_regression")

57 

58 if is_regression is None: 

59 model: VectorModel 

60 return model.is_regression_model() 

61 return is_regression 

62 

63 

64def create_vector_model_evaluator(data: InputOutputData, model: VectorModel = None, 

65 is_regression: bool = None, params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None, 

66 test_data: Optional[InputOutputData] = None) \ 

67 -> Union[VectorRegressionModelEvaluator, VectorClassificationModelEvaluator]: 

68 is_regression = _is_regression(model, is_regression) 

69 if params is None: 

70 if is_regression: 

71 params = RegressionEvaluatorParams(fractional_split_test_fraction=0.2) 

72 else: 

73 params = ClassificationEvaluatorParams(fractional_split_test_fraction=0.2) 

74 log.debug(f"No evaluator parameters specified, using default: {params}") 

75 if is_regression: 

76 return VectorRegressionModelEvaluator(data, test_data=test_data, params=params) 

77 else: 

78 return VectorClassificationModelEvaluator(data, test_data=test_data, params=params) 

79 

80 

81def create_vector_model_cross_validator(data: InputOutputData, 

82 model: VectorModel = None, 

83 is_regression: bool = None, 

84 params: Union[VectorModelCrossValidatorParams, Dict[str, Any]] = None) \ 

85 -> Union[VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator]: 

86 if params is None: 

87 raise ValueError("params must not be None") 

88 cons = VectorRegressionModelCrossValidator if _is_regression(model, is_regression) else VectorClassificationModelCrossValidator 

89 return cons(data, params=params) 

90 

91 

92def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None, 

93 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None, 

94 cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \ 

95 -> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]: 

96 if _is_regression(model, is_regression): 

97 return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data) 

98 else: 

99 return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data) 

100 
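# Sketch of constructing the evaluation utility with explicit parameters (illustrative only;
# assumes a regression setting and uses the parameter classes imported above):
#
#     ev = create_evaluation_util(io_data, is_regression=True,
#         evaluator_params=RegressionEvaluatorParams(fractional_split_test_fraction=0.3),
#         cross_validator_params=VectorModelCrossValidatorParams(folds=10))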

101 

102def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2, 

103 plot_target_distribution=False, compute_probabilities=True, normalize_plots=True, random_seed=60) -> TEvalData: 

104 """ 

105 Evaluates the given model via a simple evaluation mechanism that uses a single split 

106 

107 :param model: the model to evaluate 

108 :param io_data: data on which to evaluate 

109 :param test_fraction: the fraction of the data to test on 

110 :param plot_target_distribution: whether to plot the target values distribution in the entire dataset 

111 :param compute_probabilities: whether to compute class probabilities (only relevant if the model is a classifier)

112 :param normalize_plots: whether to normalize plotted distributions such that they sum/integrate to 1

113 :param random_seed: the random seed to use for the fractional train/test split

114 

115 :return: the evaluation data 

116 """ 

117 if plot_target_distribution: 

118 title = "Distribution of target values in entire dataset" 

119 fig = plt.figure(title) 

120 

121 output_distribution_series = io_data.outputs.iloc[:, 0] 

122 log.info(f"Description of target column in the entire dataset: \n{output_distribution_series.describe()}")

123 if not model.is_regression_model(): 

124 output_distribution_series = output_distribution_series.value_counts(normalize=normalize_plots) 

125 ax = sns.barplot(x=output_distribution_series.index, y=output_distribution_series.values)

126 ax.set_ylabel("%") 

127 else: 

128 ax = sns.histplot(output_distribution_series, kde=True, stat="density")

129 ax.set_ylabel("Probability density") 

130 ax.set_title(title) 

131 ax.set_xlabel("target value") 

132 fig.show() 

133 

134 if model.is_regression_model(): 

135 evaluator_params = RegressionEvaluatorParams(fractional_split_test_fraction=test_fraction, 

136 fractional_split_random_seed=random_seed) 

137 else: 

138 evaluator_params = ClassificationEvaluatorParams(fractional_split_test_fraction=test_fraction, 

139 compute_probabilities=compute_probabilities, fractional_split_random_seed=random_seed) 

140 ev = create_evaluation_util(io_data, model=model, evaluator_params=evaluator_params) 

141 return ev.perform_simple_evaluation(model, show_plots=True, log_results=True) 

142 
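# Hypothetical single-split usage of the convenience function above (assumes `clf` is a
# VectorClassificationModel and `io_data` an InputOutputData):
#
#     eval_data = eval_model_via_evaluator(clf, io_data, test_fraction=0.2, plot_target_distribution=True)
#     log.info(eval_data.get_eval_stats())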

143 

144class EvaluationResultCollector: 

145 def __init__(self, show_plots: bool = True, result_writer: Optional[ResultWriter] = None, 

146 tracking_context: TrackingContext = None): 

147 self.show_plots = show_plots 

148 self.result_writer = result_writer 

149 self.tracking_context = tracking_context 

150 

151 def is_plot_creation_enabled(self) -> bool: 

152 return self.show_plots or self.result_writer is not None or self.tracking_context is not None 

153 

154 def add_figure(self, name: str, fig: matplotlib.figure.Figure): 

155 if self.result_writer is not None: 

156 self.result_writer.write_figure(name, fig, close_figure=False) 

157 if self.tracking_context is not None: 

158 self.tracking_context.track_figure(name, fig) 

159 if not self.show_plots: 

160 plt.close(fig) 

161 

162 def add_data_frame_csv_file(self, name: str, df: pd.DataFrame): 

163 if self.result_writer is not None: 

164 self.result_writer.write_data_frame_csv_file(name, df) 

165 

166 def child(self, added_filename_prefix): 

167 result_writer = self.result_writer 

168 if result_writer: 

169 result_writer = result_writer.child_with_added_prefix(added_filename_prefix) 

170 return self.__class__(show_plots=self.show_plots, result_writer=result_writer) 

171 

172 

173class EvalStatsPlotCollector(Generic[TEvalStats, TEvalStatsPlot]): 

174 def __init__(self): 

175 self.plots: Dict[str, EvalStatsPlot] = {} 

176 self.disabled_plots: Set[str] = set() 

177 

178 def add_plot(self, name: str, plot: EvalStatsPlot): 

179 self.plots[name] = plot 

180 

181 def get_enabled_plots(self) -> List[str]: 

182 return [p for p in self.plots if p not in self.disabled_plots] 

183 

184 def disable_plots(self, *names: str): 

185 self.disabled_plots.update(names) 

186 

187 def create_plots(self, eval_stats: EvalStats, subtitle: str, result_collector: EvaluationResultCollector): 

188 known_plots = set(self.plots.keys()) 

189 unknown_disabled_plots = self.disabled_plots.difference(known_plots) 

190 if len(unknown_disabled_plots) > 0: 

191 log.warning(f"Plots were disabled which are not registered: {unknown_disabled_plots}; known plots: {known_plots}") 

192 for name, plot in self.plots.items(): 

193 if name not in self.disabled_plots and plot.is_applicable(eval_stats): 

194 fig = plot.create_figure(eval_stats, subtitle) 

195 if fig is not None: 

196 result_collector.add_figure(name, fig) 

197 

198 

199class RegressionEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]): 

200 def __init__(self): 

201 super().__init__() 

202 self.add_plot("error-dist", RegressionEvalStatsPlotErrorDistribution()) 

203 self.add_plot("heatmap-gt-pred", RegressionEvalStatsPlotHeatmapGroundTruthPredictions(weighted=False)) 

204 self.add_plot("heatmap-gt-pred-weighted", RegressionEvalStatsPlotHeatmapGroundTruthPredictions(weighted=True)) 

205 self.add_plot("scatter-gt-pred", RegressionEvalStatsPlotScatterGroundTruthPredictions()) 

206 

207 

208class ClassificationEvalStatsPlotCollector(EvalStatsPlotCollector[ClassificationEvalStats, EvalStatsPlot]):

209 def __init__(self): 

210 super().__init__() 

211 self.add_plot("confusion-matrix-rel", ClassificationEvalStatsPlotConfusionMatrix(normalise=True)) 

212 self.add_plot("confusion-matrix-abs", ClassificationEvalStatsPlotConfusionMatrix(normalise=False)) 

213 # the plots below apply to the binary case only (skipped for non-binary case) 

214 self.add_plot("precision-recall", ClassificationEvalStatsPlotPrecisionRecall()) 

215 self.add_plot("threshold-precision-recall", ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall()) 

216 self.add_plot("threshold-counts", ClassificationEvalStatsPlotProbabilityThresholdCounts()) 

217 
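# Sketch: the plot collectors above allow individual plots to be disabled by name before an evaluation
# is run (illustrative; assumes `ev` is a ModelEvaluation instance as defined further below):
#
#     ev.eval_stats_plot_collector.disable_plots("threshold-counts", "confusion-matrix-abs")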

218 

219class ModelEvaluation(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]): 

220 """ 

221 Utility class for the evaluation of models based on a dataset 

222 """ 

223 def __init__(self, io_data: InputOutputData, 

224 eval_stats_plot_collector: Union[RegressionEvalStatsPlotCollector, ClassificationEvalStatsPlotCollector], 

225 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, 

226 Dict[str, Any]]] = None, 

227 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None, 

228 test_io_data: Optional[InputOutputData] = None): 

229 """ 

230 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split 

231 into training and test data according to the rules specified by `evaluator_params`. 

232 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is 

233 taken to be the test data when creating evaluators for simple (single-split) evaluation. 

234 :param eval_stats_plot_collector: a collector for plots generated from evaluation stats objects 

235 :param evaluator_params: parameters with which to instantiate evaluators 

236 :param cross_validator_params: parameters with which to instantiate cross-validators 

237 :param test_io_data: optional test data (see `io_data`) 

238 """ 

239 if cross_validator_params is None: 

240 cross_validator_params = VectorModelCrossValidatorParams(folds=5) 

241 self.evaluator_params = evaluator_params 

242 self.cross_validator_params = cross_validator_params 

243 self.io_data = io_data 

244 self.test_io_data = test_io_data 

245 self.eval_stats_plot_collector = eval_stats_plot_collector 

246 

247 def create_evaluator(self, model: TModel = None, is_regression: bool = None) -> TEvaluator: 

248 """ 

249 Creates an evaluator holding the current input-output data 

250 

251 :param model: the model for which to create an evaluator (just for reading off regression or classification, 

252 the resulting evaluator will work on other models as well) 

253 :param is_regression: whether to create a regression model evaluator. Either this or model has to be specified

254 :return: an evaluator 

255 """ 

256 return create_vector_model_evaluator(self.io_data, model=model, is_regression=is_regression, test_data=self.test_io_data, 

257 params=self.evaluator_params) 

258 

259 def create_cross_validator(self, model: TModel = None, is_regression: bool = None) -> TCrossValidator: 

260 """ 

261 Creates a cross-validator holding the current input-output data 

262 

263 :param model: the model for which to create a cross-validator (just for reading off regression or classification, 

264 the resulting evaluator will work on other models as well) 

265 :param is_regression: whether to create a regression model cross-validator. Either this or model has to be specified

266 :return: a cross-validator

267 """ 

268 return create_vector_model_cross_validator(self.io_data, model=model, is_regression=is_regression, 

269 params=self.cross_validator_params) 

270 

271 def perform_simple_evaluation(self, model: TModel, 

272 create_plots=True, show_plots=False, 

273 log_results=True, 

274 result_writer: ResultWriter = None, 

275 additional_evaluation_on_training_data=False, 

276 fit_model=True, write_eval_stats=False, 

277 tracked_experiment: TrackedExperiment = None, 

278 evaluator: Optional[TEvaluator] = None) -> TEvalData: 

279 

280 if show_plots and not create_plots: 

281 raise ValueError("show_plots=True requires create_plots=True")

282 result_writer = self._result_writer_for_model(result_writer, model) 

283 if evaluator is None: 

284 evaluator = self.create_evaluator(model) 

285 if tracked_experiment is not None: 

286 evaluator.set_tracked_experiment(tracked_experiment) 

287 log.info(f"Evaluating {model} via {evaluator}") 

288 

289 def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_prefix=""): 

290 str_eval_results = "" 

291 for predictedVarName in result_data.predicted_var_names: 

292 eval_stats = result_data.get_eval_stats(predictedVarName) 

293 str_eval_result = str(eval_stats) 

294 if log_results: 

295 log.info(f"{subtitle_prefix}Evaluation results for {predictedVarName}: {str_eval_result}") 

296 str_eval_results += predictedVarName + ": " + str_eval_result + "\n" 

297 if write_eval_stats and res_writer is not None: 

298 res_writer.write_pickle(f"eval-stats-{predictedVarName}", eval_stats) 

299 str_eval_results += f"\n\n{pretty_string_repr(model)}" 

300 if res_writer is not None: 

301 res_writer.write_text_file("evaluator-results", str_eval_results) 

302 if create_plots: 

303 with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext: 

304 self.create_plots(result_data, show_plots=show_plots, result_writer=res_writer, 

305 subtitle_prefix=subtitle_prefix, tracking_context=trackingContext) 

306 

307 eval_result_data = evaluator.eval_model(model, fit=fit_model) 

308 gather_results(eval_result_data, result_writer) 

309 if additional_evaluation_on_training_data: 

310 eval_result_data_train = evaluator.eval_model(model, on_training_data=True, track=False) 

311 additional_result_writer = result_writer.child_with_added_prefix("onTrain-") if result_writer is not None else None 

312 gather_results(eval_result_data_train, additional_result_writer, subtitle_prefix="[onTrain] ") 

313 return eval_result_data 

314 

315 @staticmethod 

316 def _result_writer_for_model(result_writer: Optional[ResultWriter], model: TModel) -> Optional[ResultWriter]: 

317 if result_writer is None: 

318 return None 

319 return result_writer.child_with_added_prefix(model.get_name() + "_") 

320 

321 def perform_cross_validation(self, model: TModel, show_plots=False, log_results=True, result_writer: Optional[ResultWriter] = None, 

322 tracked_experiment: TrackedExperiment = None, cross_validator: Optional[TCrossValidator] = None) -> TCrossValData: 

323 """ 

324 Evaluates the given model via cross-validation 

325 

326 :param model: the model to evaluate 

327 :param show_plots: whether to show plots that visualise evaluation results (combining all folds) 

328 :param log_results: whether to log evaluation results 

329 :param result_writer: a writer with which to store text files and plots. The evaluated model's name is added to each filename 

330 automatically 

331 :param tracked_experiment: a tracked experiment with which results shall be associated 

332 :param cross_validator: the cross-validator to apply; if None, a suitable cross-validator will be created

333 :return: cross-validation result data

334 """ 

335 result_writer = self._result_writer_for_model(result_writer, model) 

336 

337 if cross_validator is None: 

338 cross_validator = self.create_cross_validator(model) 

339 if tracked_experiment is not None: 

340 cross_validator.set_tracked_experiment(tracked_experiment) 

341 

342 cross_validation_data = cross_validator.eval_model(model) 

343 

344 agg_stats_by_var = {varName: cross_validation_data.get_eval_stats_collection(predicted_var_name=varName).agg_metrics_dict() 

345 for varName in cross_validation_data.predicted_var_names} 

346 df = pd.DataFrame.from_dict(agg_stats_by_var, orient="index") 

347 

348 str_eval_results = df.to_string() 

349 if log_results: 

350 log.info(f"Cross-validation results:\n{str_eval_results}") 

351 if result_writer is not None: 

352 result_writer.write_text_file("crossval-results", str_eval_results) 

353 

354 with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext: 

355 self.create_plots(cross_validation_data, show_plots=show_plots, result_writer=result_writer, 

356 tracking_context=trackingContext) 

357 

358 return cross_validation_data 

359 

360 def compare_models(self, models: Sequence[TModel], result_writer: Optional[ResultWriter] = None, use_cross_validation=False, 

361 fit_models=True, write_individual_results=True, sort_column: Optional[str] = None, sort_ascending: bool = True, 

362 sort_column_move_to_left=True, 

363 also_include_unsorted_results: bool = False, also_include_cross_val_global_stats: bool = False, 

364 visitors: Optional[Iterable["ModelComparisonVisitor"]] = None, 

365 write_visitor_results=False, write_csv=False, 

366 tracked_experiment: Optional[TrackedExperiment] = None) -> "ModelComparisonData": 

367 """ 

368 Compares several models via simple evaluation or cross-validation 

369 

370 :param models: the models to compare 

371 :param result_writer: a writer with which to store results of the comparison 

372 :param use_cross_validation: whether to use cross-validation in order to evaluate models; if False, use a simple evaluation 

373 on test data (single split) 

374 :param fit_models: whether to fit models before evaluating them; this can only be False if use_cross_validation=False

375 :param write_individual_results: whether to write results files on each individual model (in addition to the comparison 

376 summary) 

377 :param sort_column: column/metric name by which to sort; the fact that the column names change when using cross-validation 

378 (aggregation function names being added) should be ignored, simply pass the (unmodified) metric name 

379 :param sort_ascending: whether to sort using `sort_column` in ascending order

380 :param sort_column_move_to_left: whether to move the `sort_column` (if any) to the very left

381 :param also_include_unsorted_results: whether to also include, for the case where the results are sorted, the unsorted table of 

382 results in the results text 

383 :param also_include_cross_val_global_stats: whether to also include, when using cross-validation, the evaluation metrics obtained 

384 when combining the predictions from all folds into a single collection. Note that for classification models, 

385 this may not always be possible (if the set of classes known to the model differs across folds)

386 :param visitors: visitors which may process individual results 

387 :param write_visitor_results: whether to collect results from visitors (if any) after the comparison 

388 :param write_csv: whether to write metrics table to CSV files 

389 :param tracked_experiment: an experiment for tracking 

390 :return: the comparison results 

391 """ 

392 # collect model evaluation results 

393 stats_list = [] 

394 result_by_model_name = {} 

395 evaluator = None 

396 cross_validator = None 

397 for i, model in enumerate(models, start=1): 

398 model_name = model.get_name() 

399 log.info(f"Evaluating model {i}/{len(models)} named '{model_name}' ...") 

400 if use_cross_validation: 

401 if not fit_models: 

402 raise ValueError("Cross-validation necessitates that models be trained several times; got fit_models=False")

403 if cross_validator is None: 

404 cross_validator = self.create_cross_validator(model) 

405 cross_val_data = self.perform_cross_validation(model, result_writer=result_writer if write_individual_results else None, 

406 cross_validator=cross_validator, tracked_experiment=tracked_experiment) 

407 model_result = ModelComparisonData.Result(cross_validation_data=cross_val_data) 

408 result_by_model_name[model_name] = model_result 

409 eval_stats_collection = cross_val_data.get_eval_stats_collection() 

410 stats_dict = eval_stats_collection.agg_metrics_dict() 

411 else: 

412 if evaluator is None: 

413 evaluator = self.create_evaluator(model) 

414 eval_data = self.perform_simple_evaluation(model, result_writer=result_writer if write_individual_results else None, 

415 fit_model=fit_models, evaluator=evaluator, tracked_experiment=tracked_experiment) 

416 model_result = ModelComparisonData.Result(eval_data=eval_data) 

417 result_by_model_name[model_name] = model_result 

418 eval_stats = eval_data.get_eval_stats() 

419 stats_dict = eval_stats.metrics_dict() 

420 stats_dict["model_name"] = model_name 

421 stats_list.append(stats_dict) 

422 if visitors is not None: 

423 for visitor in visitors: 

424 visitor.visit(model_name, model_result) 

425 results_df = pd.DataFrame(stats_list).set_index("model_name") 

426 

427 # compute results data frame with combined set of data points (for cross-validation only) 

428 cross_val_combined_results_df = None 

429 if use_cross_validation and also_include_cross_val_global_stats: 

430 try: 

431 rows = [] 

432 for model_name, result in result_by_model_name.items(): 

433 stats_dict = result.cross_validation_data.get_eval_stats_collection().get_global_stats().metrics_dict() 

434 stats_dict["model_name"] = model_name 

435 rows.append(stats_dict) 

436 cross_val_combined_results_df = pd.DataFrame(rows).set_index("model_name") 

437 except Exception as e: 

438 log.error(f"Creation of global stats data frame from cross-validation folds failed: {e}") 

439 

440 def sorted_df(df, sort_col): 

441 if sort_col is not None: 

442 if sort_col not in df.columns: 

443 alt_sort_col = f"mean[{sort_col}]" 

444 if alt_sort_col in df.columns: 

445 sort_col = alt_sort_col 

446 else: 

447 log.warning(f"Requested sort column '{sort_col}' (or '{alt_sort_col}') not in list of columns {list(df.columns)}")

448 sort_col = None

449 if sort_col is not None: 

450 df = df.sort_values(sort_col, ascending=sort_ascending, inplace=False) 

451 if sort_column_move_to_left: 

452 df = df[[sort_col] + [c for c in df.columns if c != sort_col]] 

453 return df 

454 

455 # write comparison results 

456 title = "Model comparison results" 

457 if use_cross_validation: 

458 title += ", aggregated across folds" 

459 sorted_results_df = sorted_df(results_df, sort_column) 

460 str_results = f"{title}:\n{sorted_results_df.to_string()}" 

461 if also_include_unsorted_results and sort_column is not None: 

462 str_results += f"\n\n{title} (unsorted):\n{results_df.to_string()}" 

463 sorted_cross_val_combined_results_df = None 

464 if cross_val_combined_results_df is not None: 

465 sorted_cross_val_combined_results_df = sorted_df(cross_val_combined_results_df, sort_column) 

466 str_results += f"\n\nModel comparison results based on combined set of data points from all folds:\n" \ 

467 f"{sorted_cross_val_combined_results_df.to_string()}" 

468 log.info(str_results) 

469 if result_writer is not None: 

470 suffix = "crossval" if use_cross_validation else "simple-eval" 

471 str_results += "\n\n" + "\n\n".join([f"{model.get_name()} = {model.pprints()}" for model in models]) 

472 result_writer.write_text_file(f"model-comparison-results-{suffix}", str_results) 

473 if write_csv: 

474 result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}", sorted_results_df) 

475 if sorted_cross_val_combined_results_df is not None: 

476 result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}-combined", 

477 sorted_cross_val_combined_results_df) 

478 

479 # write visitor results 

480 if visitors is not None and write_visitor_results: 

481 result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer) 

482 for visitor in visitors: 

483 visitor.collect_results(result_collector) 

484 

485 return ModelComparisonData(results_df, result_by_model_name, evaluator=evaluator, cross_validator=cross_validator) 

486 

487 def compare_models_cross_validation(self, models: Sequence[TModel], 

488 result_writer: Optional[ResultWriter] = None) -> "ModelComparisonData": 

489 """ 

490 Compares several models via cross-validation 

491 

492 :param models: the models to compare 

493 :param result_writer: a writer with which to store results of the comparison 

494 :return: the comparison results 

495 """ 

496 return self.compare_models(models, result_writer=result_writer, use_cross_validation=True) 

497 

498 def create_plots(self, data: Union[TEvalData, TCrossValData], show_plots=True, result_writer: Optional[ResultWriter] = None, 

499 subtitle_prefix: str = "", tracking_context: Optional[TrackingContext] = None): 

500 """ 

501 Creates default plots that visualise the results in the given evaluation data 

502 

503 :param data: the evaluation data for which to create the default plots 

504 :param show_plots: whether to show plots 

505 :param result_writer: if not None, plots will be written using this writer 

506 :param subtitle_prefix: a prefix to add to the subtitle (which itself is the model name) 

507 :param tracking_context: the experiment tracking context 

508 """ 

509 result_collector = EvaluationResultCollector(show_plots=show_plots, result_writer=result_writer, 

510 tracking_context=tracking_context) 

511 if result_collector.is_plot_creation_enabled(): 

512 self._create_plots(data, result_collector, subtitle=subtitle_prefix + data.model_name) 

513 

514 def _create_plots(self, data: Union[TEvalData, TCrossValData], result_collector: EvaluationResultCollector, subtitle=None): 

515 

516 def create_plots(pred_var_name, res_collector, subt): 

517 if isinstance(data, VectorModelCrossValidationData): 

518 eval_stats = data.get_eval_stats_collection(predicted_var_name=pred_var_name).get_global_stats() 

519 elif isinstance(data, VectorModelEvaluationData): 

520 eval_stats = data.get_eval_stats(predicted_var_name=pred_var_name) 

521 else: 

522 raise ValueError(f"Unexpected argument: data={data}") 

523 return self._create_eval_stats_plots(eval_stats, res_collector, subtitle=subt) 

524 

525 predicted_var_names = data.predicted_var_names 

526 if len(predicted_var_names) == 1: 

527 create_plots(predicted_var_names[0], result_collector, subtitle) 

528 else: 

529 for predictedVarName in predicted_var_names: 

530 create_plots(predictedVarName, result_collector.child(predictedVarName + "-"), f"{predictedVarName}, {subtitle}") 

531 

532 def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: EvaluationResultCollector, subtitle=None): 

533 """ 

534 :param eval_stats: the evaluation results for which to create plots 

535 :param result_collector: the collector to which all plots are to be passed 

536 :param subtitle: the subtitle to use for generated plots (if any) 

537 """ 

538 self.eval_stats_plot_collector.create_plots(eval_stats, subtitle, result_collector) 

539 
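# Sketch of a model comparison based on the class above (illustrative; assumes `models` is a sequence of
# named VectorRegressionModels and "RMSE" is one of the reported metric columns):
#
#     comparison = ev.compare_models(models, use_cross_validation=True, sort_column="RMSE", sort_ascending=True)
#     log.info(f"\n{comparison.results_df.to_string()}")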

540 

541class RegressionModelEvaluation(ModelEvaluation[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData, 

542 VectorRegressionModelCrossValidator, VectorRegressionModelCrossValidationData, RegressionEvalStats]): 

543 def __init__(self, io_data: InputOutputData, 

544 evaluator_params: Optional[Union[RegressionEvaluatorParams, Dict[str, Any]]] = None, 

545 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None, 

546 test_io_data: Optional[InputOutputData] = None): 

547 """ 

548 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split 

549 into training and test data according to the rules specified by `evaluator_params`. 

550 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is 

551 taken to be the test data when creating evaluators for simple (single-split) evaluation. 

552 :param evaluator_params: parameters with which to instantiate evaluators 

553 :param cross_validator_params: parameters with which to instantiate cross-validators 

554 :param test_io_data: optional test data (see `io_data`) 

555 """ 

556 super().__init__(io_data, eval_stats_plot_collector=RegressionEvalStatsPlotCollector(), evaluator_params=evaluator_params, 

557 cross_validator_params=cross_validator_params, test_io_data=test_io_data) 

558 

559 

560class ClassificationModelEvaluation(ModelEvaluation[VectorClassificationModel, VectorClassificationModelEvaluator, 

561 VectorClassificationModelEvaluationData, VectorClassificationModelCrossValidator, VectorClassificationModelCrossValidationData, 

562 ClassificationEvalStats]): 

563 def __init__(self, io_data: InputOutputData, 

564 evaluator_params: Optional[Union[ClassificationEvaluatorParams, Dict[str, Any]]] = None, 

565 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None, 

566 test_io_data: Optional[InputOutputData] = None): 

567 """ 

568 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split 

569 into training and test data according to the rules specified by `evaluator_params`. 

570 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is 

571 taken to be the test data when creating evaluators for simple (single-split) evaluation. 

572 :param evaluator_params: parameters with which to instantiate evaluators 

573 :param cross_validator_params: parameters with which to instantiate cross-validators 

574 :param test_io_data: optional test data (see `io_data`) 

575 """ 

576 super().__init__(io_data, eval_stats_plot_collector=ClassificationEvalStatsPlotCollector(), evaluator_params=evaluator_params, 

577 cross_validator_params=cross_validator_params, test_io_data=test_io_data) 

578 

579 

580class MultiDataModelEvaluation: 

581 def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset", 

582 meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None, 

583 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None, 

584 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None, 

585 test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None): 

586 """ 

587 :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models. 

588 For evaluation or cross-validation, these datasets will usually be split according to the rules 

589 specified by `evaluator_params` or `cross_validator_params`. An exception is the case where

590 explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data 

591 sets, the io_data will not be split for evaluation, but the test_io_data will be used instead. 

592 :param key_name: a name for the key value used in io_data_dict, which will be used as a column name in result data frames

593 :param meta_data_dict: a dictionary which maps from a name (same keys as in io_data_dict) to a dictionary, which maps

594 from a column name to a value and which is to be used to extend the result data frames containing per-dataset results 

595 :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation==False)

596 :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation==True)

597 :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None. 

598 Entries with non-None values will be used for evaluation of the models that were trained on the respective io_data_dict. 

599 If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g. 

600 if you want to use test data sets for some entries, and splitting of the io_data for others). 

601 If not None, cross-validation cannot be used when calling ``compare_models``. 

602 """ 

603 if test_io_data_dict is not None: 

604 missing_keys = set(io_data_dict).difference(test_io_data_dict) 

605 if len(missing_keys) > 0: 

606 raise ValueError( 

607 "If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys." 

608 f"However, found missing_keys: {missing_keys}") 

609 self.io_data_dict = io_data_dict 

610 self.test_io_data_dict = test_io_data_dict 

611 

612 self.key_name = key_name 

613 self.evaluator_params = evaluator_params 

614 self.cross_validator_params = cross_validator_params 

615 if meta_data_dict is not None: 

616 self.meta_df = pd.DataFrame(meta_data_dict.values(), index=meta_data_dict.keys()) 

617 else: 

618 self.meta_df = None 

619 

620 def compare_models(self, 

621 model_factories: Sequence[Callable[[], Union[VectorRegressionModel, VectorClassificationModel]]], 

622 use_cross_validation=False, 

623 result_writer: Optional[ResultWriter] = None, 

624 write_per_dataset_results=False, 

625 write_csvs=False, 

626 column_name_for_model_ranking: str = None, 

627 rank_max=True, 

628 add_combined_eval_stats=False, 

629 create_metric_distribution_plots=True, 

630 create_combined_eval_stats_plots=False, 

631 distribution_plots_cdf=True,

632 distribution_plots_cdf_complementary=False,

633 visitors: Optional[Iterable["ModelComparisonVisitor"]] = None) \ 

634 -> Union["RegressionMultiDataModelComparisonData", "ClassificationMultiDataModelComparisonData"]: 

635 """ 

636 :param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result 

637 in a model with a fixed model name (otherwise results cannot be correctly aggregated) 

638 :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation. 

639 This can only be used if the instance's ``test_io_data_dict`` is None. 

640 :param result_writer: a writer with which to store results; if None, results are not stored 

641 :param write_per_dataset_results: whether to use result_writer (if not None) in order to generate detailed results for each

642 dataset in a subdirectory named according to the name of the dataset 

643 :param write_csvs: whether to write metrics table to CSV files 

644 :param column_name_for_model_ranking: column name to use for ranking models 

645 :param rank_max: if true, use max for ranking, else min 

646 :param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from

647 all EvalStats objects. 

648 Note that for classification, this is only possible if all individual experiments use the same set of class labels. 

649 :param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the 

650 datasets (applies only if result_writer is not None) 

651 :param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual 

652 experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only

653 if result_writer is not None).

654 Note that for classification, this is only possible if all individual experiments use the same set of class labels. 

655 :param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if 

656 create_metric_distribution_plots is True and result_writer is not None. 

657 :param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that 

658 distribution_plots_cdf is True. 

659 :param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the 

660 comparison. 

661 :return: an object containing the full comparison results 

662 """ 

663 if self.test_io_data_dict and use_cross_validation: 

664 raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified") 

665 

666 all_results_df = pd.DataFrame() 

667 eval_stats_by_model_name = defaultdict(list) 

668 results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list) 

669 is_regression = None 

670 plot_collector: Optional[EvalStatsPlotCollector] = None 

671 model_names = None 

672 model_name_to_string_repr = None 

673 

674 for i, (key, inputOutputData) in enumerate(self.io_data_dict.items(), start=1): 

675 log.info(f"Evaluating models for data set #{i}/{len(self.io_data_dict)}: {self.key_name}={key}") 

676 models = [f() for f in model_factories] 

677 

678 current_model_names = [model.get_name() for model in models] 

679 if model_names is None: 

680 model_names = current_model_names 

681 elif model_names != current_model_names: 

682 log.warning(f"Model factories do not produce fixed names; use model.with_name to name your models. "

683 f"Got {current_model_names}, previously got {model_names}") 

684 

685 if is_regression is None: 

686 models_are_regression = [model.is_regression_model() for model in models] 

687 if all(models_are_regression): 

688 is_regression = True 

689 elif not any(models_are_regression): 

690 is_regression = False 

691 else: 

692 raise ValueError("The models have to be either all regression models or all classification models, not a mixture")

693 

694 test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None 

695 ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params, 

696 cross_validator_params=self.cross_validator_params, test_io_data=test_io_data) 

697 

698 if plot_collector is None: 

699 plot_collector = ev.eval_stats_plot_collector 

700 

701 # compute data frame with results for current data set 

702 if write_per_dataset_results and result_writer is not None: 

703 child_result_writer = result_writer.child_for_subdirectory(key) 

704 else: 

705 child_result_writer = None 

706 comparison_data = ev.compare_models(models, use_cross_validation=use_cross_validation, result_writer=child_result_writer, 

707 visitors=visitors, write_visitor_results=False) 

708 df = comparison_data.results_df 

709 

710 # augment data frame 

711 df[self.key_name] = key 

712 df["model_name"] = df.index 

713 df = df.reset_index(drop=True) 

714 

715 # collect eval stats objects by model name 

716 for modelName, result in comparison_data.result_by_model_name.items(): 

717 if use_cross_validation: 

718 eval_stats = result.cross_validation_data.get_eval_stats_collection().get_global_stats() 

719 else: 

720 eval_stats = result.eval_data.get_eval_stats() 

721 eval_stats_by_model_name[modelName].append(eval_stats) 

722 results_by_model_name[modelName].append(result) 

723 

724 all_results_df = pd.concat((all_results_df, df)) 

725 

726 if model_name_to_string_repr is None: 

727 model_name_to_string_repr = {model.get_name(): model.pprints() for model in models} 

728 

729 if self.meta_df is not None: 

730 all_results_df = all_results_df.join(self.meta_df, on=self.key_name, how="left") 

731 

732 str_all_results = f"All results:\n{all_results_df.to_string()}" 

733 log.info(str_all_results) 

734 

735 # create mean result by model, removing any metrics/columns that produced NaN values 

736 # (because the mean would be computed without them, skipna parameter unsupported) 

737 all_results_grouped = all_results_df.drop(columns=self.key_name).dropna(axis=1).groupby("model_name") 

738 mean_results_df: pd.DataFrame = all_results_grouped.mean() 

739 for colName in [column_name_for_model_ranking, f"mean[{column_name_for_model_ranking}]"]: 

740 if colName in mean_results_df: 

741 mean_results_df.sort_values(colName, inplace=True, ascending=not rank_max)

742 break 

743 str_mean_results = f"Mean results (averaged across {len(self.io_data_dict)} data sets):\n{mean_results_df.to_string()}" 

744 log.info(str_mean_results) 

745 

746 def iter_combined_eval_stats_from_all_data_sets(): 

747 for model_name, evalStatsList in eval_stats_by_model_name.items(): 

748 if is_regression: 

749 ev_stats = RegressionEvalStatsCollection(evalStatsList).get_global_stats() 

750 else: 

751 ev_stats = ClassificationEvalStatsCollection(evalStatsList).get_global_stats() 

752 yield model_name, ev_stats 

753 

754 # create further aggregations 

755 agg_dfs = [] 

756 for op_name, agg_fn in [("mean", lambda x: x.mean()), ("std", lambda x: x.std()), ("min", lambda x: x.min()), 

757 ("max", lambda x: x.max())]: 

758 agg_df = agg_fn(all_results_grouped) 

759 agg_df.columns = [f"{op_name}[{c}]" for c in agg_df.columns] 

760 agg_dfs.append(agg_df) 

761 further_aggs_df = pd.concat(agg_dfs, axis=1) 

762 further_aggs_df = further_aggs_df.loc[mean_results_df.index] # apply same sort order (index is model_name) 

763 column_order = functools.reduce(lambda a, b: a + b, [list(t) for t in zip(*[df.columns for df in agg_dfs])]) 

764 further_aggs_df = further_aggs_df[column_order] 

765 str_further_aggs = f"Further aggregations:\n{further_aggs_df.to_string()}" 

766 log.info(str_further_aggs) 

767 

768 # combined eval stats from all datasets (per model) 

769 str_combined_eval_stats = "" 

770 if add_combined_eval_stats: 

771 rows = [] 

772 for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets(): 

773 rows.append({"model_name": modelName, **eval_stats.metrics_dict()}) 

774 combined_stats_df = pd.DataFrame(rows) 

775 combined_stats_df.set_index("model_name", drop=True, inplace=True) 

776 combined_stats_df = combined_stats_df.loc[mean_results_df.index] # apply same sort order (index is model_name) 

777 str_combined_eval_stats = f"Results on combined test data from all data sets:\n{combined_stats_df.to_string()}\n\n" 

778 log.info(str_combined_eval_stats) 

779 

780 if result_writer is not None: 

781 comparison_content = str_mean_results + "\n\n" + str_further_aggs + "\n\n" + str_combined_eval_stats + str_all_results 

782 comparison_content += "\n\nModels [example instance]:\n\n" 

783 comparison_content += "\n\n".join(f"{name} = {s}" for name, s in model_name_to_string_repr.items()) 

784 result_writer.write_text_file("model-comparison-results", comparison_content) 

785 if write_csvs: 

786 result_writer.write_data_frame_csv_file("all-results", all_results_df) 

787 result_writer.write_data_frame_csv_file("mean-results", mean_results_df) 

788 

789 # create plots from combined data for each model 

790 if create_combined_eval_stats_plots: 

791 for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets(): 

792 child_result_writer = result_writer.child_with_added_prefix(modelName + "_") if result_writer is not None else None 

793 result_collector = EvaluationResultCollector(show_plots=False, result_writer=child_result_writer) 

794 plot_collector.create_plots(eval_stats, subtitle=modelName, result_collector=result_collector) 

795 

796 # collect results from visitors (if any) 

797 result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer) 

798 if visitors is not None: 

799 for visitor in visitors: 

800 visitor.collect_results(result_collector) 

801 

802 # create result 

803 dataset_names = list(self.io_data_dict.keys()) 

804 if is_regression: 

805 mdmc_data = RegressionMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df, eval_stats_by_model_name, 

806 results_by_model_name, dataset_names, model_name_to_string_repr) 

807 else: 

808 mdmc_data = ClassificationMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df, 

809 eval_stats_by_model_name, results_by_model_name, dataset_names, model_name_to_string_repr) 

810 

811 # plot distributions 

812 if create_metric_distribution_plots and result_writer is not None: 

813 mdmc_data.create_distribution_plots(result_writer, cdf=distribution_plots_cdf, 

814 cdf_complementary=distribution_plots_cdf_complementary) 

815 

816 return mdmc_data 

817 
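# Sketch of a multi-dataset comparison using the class above (illustrative; `make_model_a` and `make_model_b`
# are hypothetical factory functions returning consistently named models):
#
#     multi_ev = MultiDataModelEvaluation({"dataset1": io_data_1, "dataset2": io_data_2})
#     mdmc_data = multi_ev.compare_models([make_model_a, make_model_b], use_cross_validation=True)
#     log.info(f"\n{mdmc_data.mean_results_df.to_string()}")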

818 

819class ModelComparisonData: 

820 @dataclass 

821 class Result: 

822 eval_data: Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData] = None 

823 cross_validation_data: Union[VectorClassificationModelCrossValidationData, VectorRegressionModelCrossValidationData] = None 

824 

825 def iter_evaluation_data(self) -> Iterator[Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData]]: 

826 if self.eval_data is not None: 

827 yield self.eval_data 

828 if self.cross_validation_data is not None: 

829 yield from self.cross_validation_data.eval_data_list 

830 

831 def __init__(self, results_df: pd.DataFrame, results_by_model_name: Dict[str, Result], evaluator: Optional[VectorModelEvaluator] = None, 

832 cross_validator: Optional[VectorModelCrossValidator] = None): 

833 self.results_df = results_df 

834 self.result_by_model_name = results_by_model_name 

835 self.evaluator = evaluator 

836 self.cross_validator = cross_validator 

837 

838 def get_best_model_name(self, metric_name: str) -> str: 

839 idx = np.argmax(self.results_df[metric_name]) 

840 return self.results_df.index[idx] 

841 

842 def get_best_model(self, metric_name: str) -> Union[VectorClassificationModel, VectorRegressionModel, VectorModelBase]: 

843 result = self.result_by_model_name[self.get_best_model_name(metric_name)] 

844 if result.eval_data is None: 

845 raise ValueError("The best model is not well-defined when using cross-validation") 

846 return result.eval_data.model 

847 
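# Sketch: retrieving the best model after a simple-evaluation comparison (illustrative; note that
# get_best_model_name maximises the given metric, and "R2" is a hypothetical metric column name):
#
#     best_model = comparison.get_best_model("R2")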

848 

849class ModelComparisonVisitor(ABC): 

850 @abstractmethod 

851 def visit(self, model_name: str, result: ModelComparisonData.Result): 

852 pass 

853 

854 @abstractmethod 

855 def collect_results(self, result_collector: EvaluationResultCollector) -> None: 

856 """ 

857 Collects results (such as figures) at the end of the model comparison, based on the results collected 

858 

859 :param result_collector: the collector to which figures are to be added 

860 """ 

861 pass 

862 

863 

864class ModelComparisonVisitorAggregatedFeatureImportance(ModelComparisonVisitor): 

865 """ 

866 During a model comparison, computes aggregated feature importance values for the model with the given name 

867 """ 

868 def __init__(self, model_name: str, feature_agg_regex: Sequence[str] = (), write_figure=True, write_data_frame_csv=False): 

869 r""" 

870 :param model_name: the name of the model for which to compute the aggregated feature importance values 

871 :param feature_agg_regex: a sequence of regular expressions describing which feature names to sum as one. Each regex must 

872 contain exactly one group. If a regex matches a feature name, the feature importance will be summed under the key 

873 of the matched group instead of the full feature name. For example, the regex r"(\w+)_\d+$" will cause "foo_1" and "foo_2" 

874 to be summed under "foo" and similarly "bar_1" and "bar_2" to be summed under "bar". 

875 """ 

876 self.model_name = model_name 

877 self.agg_feature_importance = AggregatedFeatureImportance(feature_agg_reg_ex=feature_agg_regex) 

878 self.write_figure = write_figure 

879 self.write_data_frame_csv = write_data_frame_csv 

880 

881 def visit(self, model_name: str, result: ModelComparisonData.Result): 

882 if model_name == self.model_name: 

883 if result.cross_validation_data is not None: 

884 models = result.cross_validation_data.trained_models 

885 if models is not None: 

886 for model in models: 

887 self._collect(model) 

888 else: 

889 raise ValueError("Models were not returned in cross-validation results") 

890 elif result.eval_data is not None: 

891 self._collect(result.eval_data.model) 

892 

893 def _collect(self, model: Union[FeatureImportanceProvider, VectorModelBase]): 

894 if not isinstance(model, FeatureImportanceProvider): 

895 raise ValueError(f"Got model which does inherit from {FeatureImportanceProvider.__qualname__}: {model}") 

896 self.agg_feature_importance.add(model.get_feature_importance_dict()) 

897 

898 @deprecated("Use get_feature_importance and create the plot using the returned object")

899 def plot_feature_importance(self) -> plt.Figure: 

900 feature_importance_dict = self.agg_feature_importance.get_aggregated_feature_importance().get_feature_importance_dict() 

901 return plot_feature_importance(feature_importance_dict, subtitle=self.model_name) 

902 

903 def get_feature_importance(self) -> FeatureImportance: 

904 return self.agg_feature_importance.get_aggregated_feature_importance() 

905 

906 def collect_results(self, result_collector: EvaluationResultCollector): 

907 feature_importance = self.get_feature_importance() 

908 if self.write_figure: 

909 result_collector.add_figure(f"{self.model_name}_feature-importance", feature_importance.plot()) 

910 if self.write_data_frame_csv: 

911 result_collector.add_data_frame_csv_file(f"{self.model_name}_feature-importance", feature_importance.get_data_frame()) 

912 
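# Sketch of using the visitor above within a comparison (illustrative; "MyGradientBoostingModel" is a
# hypothetical model name, and the regex collapses one-hot encoded feature columns such as "foo_1", "foo_2"):
#
#     visitor = ModelComparisonVisitorAggregatedFeatureImportance("MyGradientBoostingModel",
#         feature_agg_regex=[r"(\w+)_\d+$"])
#     ev.compare_models(models, visitors=[visitor], write_visitor_results=True)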

913 

914class MultiDataModelComparisonData(Generic[TEvalStats, TEvalStatsCollection], ABC): 

915 def __init__(self, all_results_df: pd.DataFrame, 

916 mean_results_df: pd.DataFrame, 

917 agg_results_df: pd.DataFrame, 

918 eval_stats_by_model_name: Dict[str, List[TEvalStats]], 

919 results_by_model_name: Dict[str, List[ModelComparisonData.Result]], 

920 dataset_names: List[str], 

921 model_name_to_string_repr: Dict[str, str]): 

922 self.all_results_df = all_results_df 

923 self.mean_results_df = mean_results_df 

924 self.agg_results_df = agg_results_df 

925 self.eval_stats_by_model_name = eval_stats_by_model_name 

926 self.results_by_model_name = results_by_model_name 

927 self.dataset_names = dataset_names 

928 self.model_name_to_string_repr = model_name_to_string_repr 

929 

930 def get_model_names(self) -> List[str]: 

931 return list(self.eval_stats_by_model_name.keys()) 

932 

933 def get_model_description(self, model_name: str) -> str: 

934 return self.model_name_to_string_repr[model_name] 

935 

936 def get_eval_stats_list(self, model_name: str) -> List[TEvalStats]: 

937 return self.eval_stats_by_model_name[model_name] 

938 

939 @abstractmethod 

940 def get_eval_stats_collection(self, model_name: str) -> TEvalStatsCollection: 

941 pass 

942 

943 def iter_model_results(self, model_name: str) -> Iterator[Tuple[str, ModelComparisonData.Result]]: 

944 results = self.results_by_model_name[model_name] 

945 yield from zip(self.dataset_names, results) 

946 

947 def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_complementary=False): 

948 """ 

949 Creates plots of distributions of metrics across datasets for each model as a histogram, and additionally 

950 any x-y plots (scatter plots & heat maps) for metrics that have associated paired metrics that were also computed 

951 

952 :param result_writer: the result writer 

953 :param cdf: whether to additionally plot, for each distribution, the cumulative distribution function 

954 :param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True 

955 """ 

956 for modelName in self.get_model_names(): 

957 eval_stats_collection = self.get_eval_stats_collection(modelName) 

958 for metricName in eval_stats_collection.get_metric_names(): 

959 # plot distribution 

960 fig = eval_stats_collection.plot_distribution(metricName, subtitle=modelName, cdf=cdf, cdf_complementary=cdf_complementary) 

961 result_writer.write_figure(f"{modelName}_dist-{metricName}", fig) 

962 # scatter plot with paired metrics 

963 metric: Metric = eval_stats_collection.get_metric_by_name(metricName) 

964 for paired_metric in metric.get_paired_metrics(): 

965 if eval_stats_collection.has_metric(paired_metric): 

966 fig = eval_stats_collection.plot_scatter(metric.name, paired_metric.name) 

967 result_writer.write_figure(f"{modelName}_scatter-{metric.name}-{paired_metric.name}", fig) 

968 fig = eval_stats_collection.plot_heat_map(metric.name, paired_metric.name) 

969 result_writer.write_figure(f"{modelName}_heatmap-{metric.name}-{paired_metric.name}", fig) 

970 

971 

972class ClassificationMultiDataModelComparisonData(MultiDataModelComparisonData[ClassificationEvalStats, ClassificationEvalStatsCollection]): 

973 def get_eval_stats_collection(self, model_name: str): 

974 return ClassificationEvalStatsCollection(self.get_eval_stats_list(model_name)) 

975 

976 

977class RegressionMultiDataModelComparisonData(MultiDataModelComparisonData[RegressionEvalStats, RegressionEvalStatsCollection]): 

978 def get_eval_stats_collection(self, model_name: str): 

979 return RegressionEvalStatsCollection(self.get_eval_stats_list(model_name))