Coverage for src/sensai/evaluation/eval_util.py: 21%

534 statements  

coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1""" 

2This module contains methods and classes that facilitate evaluation of different types of models. The suggested 

3workflow for evaluation is to use these higher-level functionalities instead of instantiating 

4the evaluation classes directly. 

5""" 

6import functools 

7import logging 

8from abc import ABC, abstractmethod 

9from collections import defaultdict 

10from dataclasses import dataclass 

11from typing import Dict, Any, Union, Generic, TypeVar, Optional, Sequence, Callable, Set, Iterable, List, Iterator, Tuple 

12 

13import matplotlib.figure 

14import matplotlib.pyplot as plt 

15import numpy as np 

16import pandas as pd 

17import seaborn as sns 

18 

19from .crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData, \ 

20 VectorClassificationModelCrossValidationData, \ 

21 VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator, VectorModelCrossValidator, VectorModelCrossValidatorParams 

22from .eval_stats import RegressionEvalStatsCollection, ClassificationEvalStatsCollection, RegressionEvalStatsPlotErrorDistribution, \ 

23 RegressionEvalStatsPlotHeatmapGroundTruthPredictions, RegressionEvalStatsPlotScatterGroundTruthPredictions, \ 

24 ClassificationEvalStatsPlotConfusionMatrix, ClassificationEvalStatsPlotPrecisionRecall, RegressionEvalStatsPlot, \ 

25 ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall, ClassificationEvalStatsPlotProbabilityThresholdCounts, \ 

26 Metric 

27from .eval_stats.eval_stats_base import EvalStats, EvalStatsCollection, EvalStatsPlot 

28from .eval_stats.eval_stats_classification import ClassificationEvalStats 

29from .eval_stats.eval_stats_regression import RegressionEvalStats 

30from .evaluator import VectorModelEvaluator, VectorModelEvaluationData, VectorRegressionModelEvaluator, \ 

31 VectorRegressionModelEvaluationData, VectorClassificationModelEvaluator, VectorClassificationModelEvaluationData, \ 

32 RegressionEvaluatorParams, ClassificationEvaluatorParams 

33from ..data import InputOutputData 

34from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plot_feature_importance, FeatureImportance 

35from ..tracking import TrackedExperiment 

36from ..tracking.tracking_base import TrackingContext 

37from ..util.deprecation import deprecated 

38from ..util.io import ResultWriter 

39from ..util.string import pretty_string_repr 

40from ..vector_model import VectorClassificationModel, VectorRegressionModel, VectorModel, VectorModelBase 

41 

42log = logging.getLogger(__name__) 

43 

44TModel = TypeVar("TModel", bound=VectorModel) 

45TEvalStats = TypeVar("TEvalStats", bound=EvalStats) 

46TEvalStatsPlot = TypeVar("TEvalStatsPlot", bound=EvalStatsPlot) 

47TEvalStatsCollection = TypeVar("TEvalStatsCollection", bound=EvalStatsCollection) 

48TEvaluator = TypeVar("TEvaluator", bound=VectorModelEvaluator) 

49TCrossValidator = TypeVar("TCrossValidator", bound=VectorModelCrossValidator) 

50TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData) 

51TCrossValData = TypeVar("TCrossValData", bound=VectorModelCrossValidationData) 

52 

53 

54def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool]) -> bool: 

55 if model is None and is_regression is None or (model is not None and is_regression is not None): 

56 raise ValueError("One of the two parameters have to be passed: model or isRegression") 

57 

58 if is_regression is None: 

59 model: VectorModel 

60 return model.is_regression_model() 

61 return is_regression 


def create_vector_model_evaluator(data: InputOutputData, model: VectorModel = None,
        is_regression: bool = None, params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None,
        test_data: Optional[InputOutputData] = None) \
        -> Union[VectorRegressionModelEvaluator, VectorClassificationModelEvaluator]:
    is_regression = _is_regression(model, is_regression)
    if params is None:
        if is_regression:
            params = RegressionEvaluatorParams(fractional_split_test_fraction=0.2)
        else:
            params = ClassificationEvaluatorParams(fractional_split_test_fraction=0.2)
        log.debug(f"No evaluator parameters specified, using default: {params}")
    if is_regression:
        return VectorRegressionModelEvaluator(data, test_data=test_data, params=params)
    else:
        return VectorClassificationModelEvaluator(data, test_data=test_data, params=params)
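
# Usage sketch (illustrative, not part of the original module): creating an evaluator directly with
# explicit parameters. `my_regression_model`, `inputs_df` and `targets_df` are assumed to be provided
# by the caller, and constructing InputOutputData from input and output data frames is an assumption.
#
#   io_data = InputOutputData(inputs_df, targets_df)
#   evaluator = create_vector_model_evaluator(io_data, model=my_regression_model,
#       params=RegressionEvaluatorParams(fractional_split_test_fraction=0.3))
#   eval_data = evaluator.eval_model(my_regression_model, fit=True)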


def create_vector_model_cross_validator(data: InputOutputData,
        model: VectorModel = None,
        is_regression: bool = None,
        params: Union[VectorModelCrossValidatorParams, Dict[str, Any]] = None) \
        -> Union[VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator]:
    if params is None:
        raise ValueError("params must not be None")
    cons = VectorRegressionModelCrossValidator if _is_regression(model, is_regression) else VectorClassificationModelCrossValidator
    return cons(data, params=params)
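
# Usage sketch (illustrative, not part of the original module): cross-validation requires explicit
# parameters; `my_model` and `io_data` are assumed to exist.
#
#   cv_params = VectorModelCrossValidatorParams(folds=5)
#   cross_validator = create_vector_model_cross_validator(io_data, model=my_model, params=cv_params)
#   cv_data = cross_validator.eval_model(my_model)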


def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
        evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
        cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
        -> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
    if _is_regression(model, is_regression):
        return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
    else:
        return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)


def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
        plot_target_distribution=False, compute_probabilities=True, normalize_plots=True, random_seed=60) -> TEvalData:
    """
    Evaluates the given model via a simple evaluation mechanism that uses a single split

    :param model: the model to evaluate
    :param io_data: data on which to evaluate
    :param test_fraction: the fraction of the data to test on
    :param plot_target_distribution: whether to plot the target values distribution in the entire dataset
    :param compute_probabilities: only relevant if the model is a classifier
    :param normalize_plots: whether to normalize plotted distributions such that they sum/integrate to 1
    :param random_seed: the random seed to use for the train/test split

    :return: the evaluation data
    """
    if plot_target_distribution:
        title = "Distribution of target values in entire dataset"
        fig = plt.figure(title)

        output_distribution_series = io_data.outputs.iloc[:, 0]
        log.info(f"Description of target column in training set: \n{output_distribution_series.describe()}")
        if not model.is_regression_model():
            output_distribution_series = output_distribution_series.value_counts(normalize=normalize_plots)
            ax = sns.barplot(output_distribution_series.index, output_distribution_series.values)
            ax.set_ylabel("%")
        else:
            ax = sns.distplot(output_distribution_series)
            ax.set_ylabel("Probability density")
        ax.set_title(title)
        ax.set_xlabel("target value")
        fig.show()

    if model.is_regression_model():
        evaluator_params = RegressionEvaluatorParams(fractional_split_test_fraction=test_fraction,
            fractional_split_random_seed=random_seed)
    else:
        evaluator_params = ClassificationEvaluatorParams(fractional_split_test_fraction=test_fraction,
            compute_probabilities=compute_probabilities, fractional_split_random_seed=random_seed)
    ev = create_evaluation_util(io_data, model=model, evaluator_params=evaluator_params)
    return ev.perform_simple_evaluation(model, show_plots=True, log_results=True)
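
# Usage sketch (illustrative, not part of the original module): the quickest way to evaluate a single
# model on a single train/test split. `my_classifier` and the data frames are assumed to be provided
# by the caller.
#
#   io_data = InputOutputData(inputs_df, targets_df)
#   eval_data = eval_model_via_evaluator(my_classifier, io_data, test_fraction=0.25,
#       plot_target_distribution=True)
#   print(eval_data.get_eval_stats())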


class EvaluationResultCollector:
    def __init__(self, show_plots: bool = True, result_writer: Optional[ResultWriter] = None,
            tracking_context: TrackingContext = None):
        self.show_plots = show_plots
        self.result_writer = result_writer
        self.tracking_context = tracking_context

    def is_plot_creation_enabled(self) -> bool:
        return self.show_plots or self.result_writer is not None or self.tracking_context is not None

    def add_figure(self, name: str, fig: matplotlib.figure.Figure):
        if self.result_writer is not None:
            self.result_writer.write_figure(name, fig, close_figure=False)
        if self.tracking_context is not None:
            self.tracking_context.track_figure(name, fig)
        if not self.show_plots:
            plt.close(fig)

    def add_data_frame_csv_file(self, name: str, df: pd.DataFrame):
        if self.result_writer is not None:
            self.result_writer.write_data_frame_csv_file(name, df)

    def child(self, added_filename_prefix):
        result_writer = self.result_writer
        if result_writer:
            result_writer = result_writer.child_with_added_prefix(added_filename_prefix)
        return self.__class__(show_plots=self.show_plots, result_writer=result_writer)


class EvalStatsPlotCollector(Generic[TEvalStats, TEvalStatsPlot]):
    def __init__(self):
        self.plots: Dict[str, EvalStatsPlot] = {}
        self.disabled_plots: Set[str] = set()

    def add_plot(self, name: str, plot: EvalStatsPlot):
        self.plots[name] = plot

    def get_enabled_plots(self) -> List[str]:
        return [p for p in self.plots if p not in self.disabled_plots]

    def disable_plots(self, *names: str):
        self.disabled_plots.update(names)

    def create_plots(self, eval_stats: EvalStats, subtitle: str, result_collector: EvaluationResultCollector):
        known_plots = set(self.plots.keys())
        unknown_disabled_plots = self.disabled_plots.difference(known_plots)
        if len(unknown_disabled_plots) > 0:
            log.warning(f"Plots were disabled which are not registered: {unknown_disabled_plots}; known plots: {known_plots}")
        for name, plot in self.plots.items():
            if name not in self.disabled_plots:
                fig = plot.create_figure(eval_stats, subtitle)
                if fig is not None:
                    result_collector.add_figure(name, fig)


class RegressionEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]):
    def __init__(self):
        super().__init__()
        self.add_plot("error-dist", RegressionEvalStatsPlotErrorDistribution())
        self.add_plot("heatmap-gt-pred", RegressionEvalStatsPlotHeatmapGroundTruthPredictions())
        self.add_plot("scatter-gt-pred", RegressionEvalStatsPlotScatterGroundTruthPredictions())


class ClassificationEvalStatsPlotCollector(EvalStatsPlotCollector[ClassificationEvalStats, EvalStatsPlot]):
    def __init__(self):
        super().__init__()
        self.add_plot("confusion-matrix-rel", ClassificationEvalStatsPlotConfusionMatrix(normalise=True))
        self.add_plot("confusion-matrix-abs", ClassificationEvalStatsPlotConfusionMatrix(normalise=False))
        # the plots below apply to the binary case only (skipped for non-binary case)
        self.add_plot("precision-recall", ClassificationEvalStatsPlotPrecisionRecall())
        self.add_plot("threshold-precision-recall", ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall())
        self.add_plot("threshold-counts", ClassificationEvalStatsPlotProbabilityThresholdCounts())
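
# Usage sketch (illustrative, not part of the original module): plots registered in the collector can
# be disabled by name before running an evaluation; `ev` is assumed to be a ClassificationModelEvaluation
# instance and `my_classifier` a fitted or fittable classifier.
#
#   ev.eval_stats_plot_collector.disable_plots("confusion-matrix-abs", "threshold-counts")
#   ev.perform_simple_evaluation(my_classifier, show_plots=True)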


class ModelEvaluation(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]):
    """
    Utility class for the evaluation of models based on a dataset
    """
    def __init__(self, io_data: InputOutputData,
            eval_stats_plot_collector: Union[RegressionEvalStatsPlotCollector, ClassificationEvalStatsPlotCollector],
            evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams,
                Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param eval_stats_plot_collector: a collector for plots generated from evaluation stats objects
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        if cross_validator_params is None:
            cross_validator_params = VectorModelCrossValidatorParams(folds=5)
        self.evaluator_params = evaluator_params
        self.cross_validator_params = cross_validator_params
        self.io_data = io_data
        self.test_io_data = test_io_data
        self.eval_stats_plot_collector = eval_stats_plot_collector

    def create_evaluator(self, model: TModel = None, is_regression: bool = None) -> TEvaluator:
        """
        Creates an evaluator holding the current input-output data

        :param model: the model for which to create an evaluator (just for reading off regression or classification,
            the resulting evaluator will work on other models as well)
        :param is_regression: whether to create a regression model evaluator. Either this or model has to be specified
        :return: an evaluator
        """
        return create_vector_model_evaluator(self.io_data, model=model, is_regression=is_regression, test_data=self.test_io_data,
            params=self.evaluator_params)

    def create_cross_validator(self, model: TModel = None, is_regression: bool = None) -> TCrossValidator:
        """
        Creates a cross-validator holding the current input-output data

        :param model: the model for which to create a cross-validator (just for reading off regression or classification,
            the resulting cross-validator will work on other models as well)
        :param is_regression: whether to create a regression model cross-validator. Either this or model has to be specified
        :return: a cross-validator
        """
        return create_vector_model_cross_validator(self.io_data, model=model, is_regression=is_regression,
            params=self.cross_validator_params)

    def perform_simple_evaluation(self, model: TModel,
            create_plots=True, show_plots=False,
            log_results=True,
            result_writer: ResultWriter = None,
            additional_evaluation_on_training_data=False,
            fit_model=True, write_eval_stats=False,
            tracked_experiment: TrackedExperiment = None,
            evaluator: Optional[TEvaluator] = None) -> TEvalData:

        if show_plots and not create_plots:
            raise ValueError("show_plots=True requires create_plots=True")
        result_writer = self._result_writer_for_model(result_writer, model)
        if evaluator is None:
            evaluator = self.create_evaluator(model)
        if tracked_experiment is not None:
            evaluator.set_tracked_experiment(tracked_experiment)
        log.info(f"Evaluating {model} via {evaluator}")

        def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_prefix=""):
            str_eval_results = ""
            for predictedVarName in result_data.predicted_var_names:
                eval_stats = result_data.get_eval_stats(predictedVarName)
                str_eval_result = str(eval_stats)
                if log_results:
                    log.info(f"{subtitle_prefix}Evaluation results for {predictedVarName}: {str_eval_result}")
                str_eval_results += predictedVarName + ": " + str_eval_result + "\n"
                if write_eval_stats and res_writer is not None:
                    res_writer.write_pickle(f"eval-stats-{predictedVarName}", eval_stats)
            str_eval_results += f"\n\n{pretty_string_repr(model)}"
            if res_writer is not None:
                res_writer.write_text_file("evaluator-results", str_eval_results)
            if create_plots:
                with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
                    self.create_plots(result_data, show_plots=show_plots, result_writer=res_writer,
                        subtitle_prefix=subtitle_prefix, tracking_context=trackingContext)

        eval_result_data = evaluator.eval_model(model, fit=fit_model)
        gather_results(eval_result_data, result_writer)
        if additional_evaluation_on_training_data:
            eval_result_data_train = evaluator.eval_model(model, on_training_data=True, track=False)
            additional_result_writer = result_writer.child_with_added_prefix("onTrain-") if result_writer is not None else None
            gather_results(eval_result_data_train, additional_result_writer, subtitle_prefix="[onTrain] ")
        return eval_result_data
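
    # Usage sketch (illustrative, not part of the original module): a typical single-split evaluation
    # that also writes results to disk; the ResultWriter constructor arguments are an assumption.
    #
    #   ev = RegressionModelEvaluation(io_data, evaluator_params=RegressionEvaluatorParams(fractional_split_test_fraction=0.2))
    #   eval_data = ev.perform_simple_evaluation(my_model, show_plots=True,
    #       result_writer=ResultWriter("results/my_experiment"))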


    @staticmethod
    def _result_writer_for_model(result_writer: Optional[ResultWriter], model: TModel) -> Optional[ResultWriter]:
        if result_writer is None:
            return None
        return result_writer.child_with_added_prefix(model.get_name() + "_")

    def perform_cross_validation(self, model: TModel, show_plots=False, log_results=True, result_writer: Optional[ResultWriter] = None,
            tracked_experiment: TrackedExperiment = None, cross_validator: Optional[TCrossValidator] = None) -> TCrossValData:
        """
        Evaluates the given model via cross-validation

        :param model: the model to evaluate
        :param show_plots: whether to show plots that visualise evaluation results (combining all folds)
        :param log_results: whether to log evaluation results
        :param result_writer: a writer with which to store text files and plots. The evaluated model's name is added to each filename
            automatically
        :param tracked_experiment: a tracked experiment with which results shall be associated
        :param cross_validator: the cross-validator to apply; if None, a suitable cross-validator will be created
        :return: cross-validation result data
        """
        result_writer = self._result_writer_for_model(result_writer, model)

        if cross_validator is None:
            cross_validator = self.create_cross_validator(model)
        if tracked_experiment is not None:
            cross_validator.set_tracked_experiment(tracked_experiment)

        cross_validation_data = cross_validator.eval_model(model)

        agg_stats_by_var = {varName: cross_validation_data.get_eval_stats_collection(predicted_var_name=varName).agg_metrics_dict()
            for varName in cross_validation_data.predicted_var_names}
        df = pd.DataFrame.from_dict(agg_stats_by_var, orient="index")

        str_eval_results = df.to_string()
        if log_results:
            log.info(f"Cross-validation results:\n{str_eval_results}")
        if result_writer is not None:
            result_writer.write_text_file("crossval-results", str_eval_results)

        with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
            self.create_plots(cross_validation_data, show_plots=show_plots, result_writer=result_writer,
                tracking_context=trackingContext)

        return cross_validation_data
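
    # Usage sketch (illustrative, not part of the original module): cross-validating a model with
    # custom fold settings; `my_model` and `io_data` are assumed to exist.
    #
    #   ev = RegressionModelEvaluation(io_data, cross_validator_params=VectorModelCrossValidatorParams(folds=10))
    #   cv_data = ev.perform_cross_validation(my_model, show_plots=False)
    #   print(cv_data.get_eval_stats_collection().agg_metrics_dict())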


    def compare_models(self, models: Sequence[TModel], result_writer: Optional[ResultWriter] = None, use_cross_validation=False,
            fit_models=True, write_individual_results=True, sort_column: Optional[str] = None, sort_ascending: bool = True,
            sort_column_move_to_left=True,
            also_include_unsorted_results: bool = False, also_include_cross_val_global_stats: bool = False,
            visitors: Optional[Iterable["ModelComparisonVisitor"]] = None,
            write_visitor_results=False, write_csv=False,
            tracked_experiment: Optional[TrackedExperiment] = None) -> "ModelComparisonData":
        """
        Compares several models via simple evaluation or cross-validation

        :param models: the models to compare
        :param result_writer: a writer with which to store results of the comparison
        :param use_cross_validation: whether to use cross-validation in order to evaluate models; if False, use a simple evaluation
            on test data (single split)
        :param fit_models: whether to fit models before evaluating them; this can only be False if use_cross_validation=False
        :param write_individual_results: whether to write results files on each individual model (in addition to the comparison
            summary)
        :param sort_column: column/metric name by which to sort; the fact that the column names change when using cross-validation
            (aggregation function names being added) should be ignored, simply pass the (unmodified) metric name
        :param sort_ascending: whether to sort using `sort_column` in ascending order
        :param sort_column_move_to_left: whether to move the `sort_column` (if any) to the very left
        :param also_include_unsorted_results: whether to also include, for the case where the results are sorted, the unsorted table of
            results in the results text
        :param also_include_cross_val_global_stats: whether to also include, when using cross-validation, the evaluation metrics obtained
            when combining the predictions from all folds into a single collection. Note that for classification models,
            this may not always be possible (if the set of classes known to the model differs across folds)
        :param visitors: visitors which may process individual results
        :param write_visitor_results: whether to collect results from visitors (if any) after the comparison
        :param write_csv: whether to write the metrics table to CSV files
        :param tracked_experiment: an experiment for tracking
        :return: the comparison results
        """
        # collect model evaluation results
        stats_list = []
        result_by_model_name = {}
        evaluator = None
        cross_validator = None
        for i, model in enumerate(models, start=1):
            model_name = model.get_name()
            log.info(f"Evaluating model {i}/{len(models)} named '{model_name}' ...")
            if use_cross_validation:
                if not fit_models:
                    raise ValueError("Cross-validation necessitates that models be trained several times; got fit_models=False")
                if cross_validator is None:
                    cross_validator = self.create_cross_validator(model)
                cross_val_data = self.perform_cross_validation(model, result_writer=result_writer if write_individual_results else None,
                    cross_validator=cross_validator, tracked_experiment=tracked_experiment)
                model_result = ModelComparisonData.Result(cross_validation_data=cross_val_data)
                result_by_model_name[model_name] = model_result
                eval_stats_collection = cross_val_data.get_eval_stats_collection()
                stats_dict = eval_stats_collection.agg_metrics_dict()
            else:
                if evaluator is None:
                    evaluator = self.create_evaluator(model)
                eval_data = self.perform_simple_evaluation(model, result_writer=result_writer if write_individual_results else None,
                    fit_model=fit_models, evaluator=evaluator, tracked_experiment=tracked_experiment)
                model_result = ModelComparisonData.Result(eval_data=eval_data)
                result_by_model_name[model_name] = model_result
                eval_stats = eval_data.get_eval_stats()
                stats_dict = eval_stats.metrics_dict()
            stats_dict["model_name"] = model_name
            stats_list.append(stats_dict)
            if visitors is not None:
                for visitor in visitors:
                    visitor.visit(model_name, model_result)
        results_df = pd.DataFrame(stats_list).set_index("model_name")

        # compute results data frame with combined set of data points (for cross-validation only)
        cross_val_combined_results_df = None
        if use_cross_validation and also_include_cross_val_global_stats:
            try:
                rows = []
                for model_name, result in result_by_model_name.items():
                    stats_dict = result.cross_validation_data.get_eval_stats_collection().get_global_stats().metrics_dict()
                    stats_dict["model_name"] = model_name
                    rows.append(stats_dict)
                cross_val_combined_results_df = pd.DataFrame(rows).set_index("model_name")
            except Exception as e:
                log.error(f"Creation of global stats data frame from cross-validation folds failed: {e}")

        def sorted_df(df, sort_col):
            if sort_col is not None:
                if sort_col not in df.columns:
                    alt_sort_col = f"mean[{sort_col}]"
                    if alt_sort_col in df.columns:
                        sort_col = alt_sort_col
                    else:
                        log.warning(f"Requested sort column '{sort_col}' (or '{alt_sort_col}') not in list of columns {list(df.columns)}")
                        sort_col = None
                if sort_col is not None:
                    df = df.sort_values(sort_col, ascending=sort_ascending, inplace=False)
                    if sort_column_move_to_left:
                        df = df[[sort_col] + [c for c in df.columns if c != sort_col]]
            return df

        # write comparison results
        title = "Model comparison results"
        if use_cross_validation:
            title += ", aggregated across folds"
        sorted_results_df = sorted_df(results_df, sort_column)
        str_results = f"{title}:\n{sorted_results_df.to_string()}"
        if also_include_unsorted_results and sort_column is not None:
            str_results += f"\n\n{title} (unsorted):\n{results_df.to_string()}"
        sorted_cross_val_combined_results_df = None
        if cross_val_combined_results_df is not None:
            sorted_cross_val_combined_results_df = sorted_df(cross_val_combined_results_df, sort_column)
            str_results += f"\n\nModel comparison results based on combined set of data points from all folds:\n" \
                f"{sorted_cross_val_combined_results_df.to_string()}"
        log.info(str_results)
        if result_writer is not None:
            suffix = "crossval" if use_cross_validation else "simple-eval"
            str_results += "\n\n" + "\n\n".join([f"{model.get_name()} = {model.pprints()}" for model in models])
            result_writer.write_text_file(f"model-comparison-results-{suffix}", str_results)
            if write_csv:
                result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}", sorted_results_df)
                if sorted_cross_val_combined_results_df is not None:
                    result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}-combined",
                        sorted_cross_val_combined_results_df)

        # write visitor results
        if visitors is not None and write_visitor_results:
            result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
            for visitor in visitors:
                visitor.collect_results(result_collector)

        return ModelComparisonData(results_df, result_by_model_name, evaluator=evaluator, cross_validator=cross_validator)
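
    # Usage sketch (illustrative, not part of the original module): comparing several models on the same
    # dataset and sorting the summary table by a metric; the metric name "RMSE" and the models are assumptions.
    #
    #   comparison = ev.compare_models([model_a, model_b, model_c], use_cross_validation=True,
    #       sort_column="RMSE", sort_ascending=True)
    #   print(comparison.results_df)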


    def compare_models_cross_validation(self, models: Sequence[TModel],
            result_writer: Optional[ResultWriter] = None) -> "ModelComparisonData":
        """
        Compares several models via cross-validation

        :param models: the models to compare
        :param result_writer: a writer with which to store results of the comparison
        :return: the comparison results
        """
        return self.compare_models(models, result_writer=result_writer, use_cross_validation=True)

    def create_plots(self, data: Union[TEvalData, TCrossValData], show_plots=True, result_writer: Optional[ResultWriter] = None,
            subtitle_prefix: str = "", tracking_context: Optional[TrackingContext] = None):
        """
        Creates default plots that visualise the results in the given evaluation data

        :param data: the evaluation data for which to create the default plots
        :param show_plots: whether to show plots
        :param result_writer: if not None, plots will be written using this writer
        :param subtitle_prefix: a prefix to add to the subtitle (which itself is the model name)
        :param tracking_context: the experiment tracking context
        """
        result_collector = EvaluationResultCollector(show_plots=show_plots, result_writer=result_writer,
            tracking_context=tracking_context)
        if result_collector.is_plot_creation_enabled():
            self._create_plots(data, result_collector, subtitle=subtitle_prefix + data.model_name)

    def _create_plots(self, data: Union[TEvalData, TCrossValData], result_collector: EvaluationResultCollector, subtitle=None):

        def create_plots(pred_var_name, res_collector, subt):
            if isinstance(data, VectorModelCrossValidationData):
                eval_stats = data.get_eval_stats_collection(predicted_var_name=pred_var_name).get_global_stats()
            elif isinstance(data, VectorModelEvaluationData):
                eval_stats = data.get_eval_stats(predicted_var_name=pred_var_name)
            else:
                raise ValueError(f"Unexpected argument: data={data}")
            return self._create_eval_stats_plots(eval_stats, res_collector, subtitle=subt)

        predicted_var_names = data.predicted_var_names
        if len(predicted_var_names) == 1:
            create_plots(predicted_var_names[0], result_collector, subtitle)
        else:
            for predictedVarName in predicted_var_names:
                create_plots(predictedVarName, result_collector.child(predictedVarName + "-"), f"{predictedVarName}, {subtitle}")

    def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: EvaluationResultCollector, subtitle=None):
        """
        :param eval_stats: the evaluation results for which to create plots
        :param result_collector: the collector to which all plots are to be passed
        :param subtitle: the subtitle to use for generated plots (if any)
        """
        self.eval_stats_plot_collector.create_plots(eval_stats, subtitle, result_collector)


class RegressionModelEvaluation(ModelEvaluation[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData,
        VectorRegressionModelCrossValidator, VectorRegressionModelCrossValidationData, RegressionEvalStats]):
    def __init__(self, io_data: InputOutputData,
            evaluator_params: Optional[Union[RegressionEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        super().__init__(io_data, eval_stats_plot_collector=RegressionEvalStatsPlotCollector(), evaluator_params=evaluator_params,
            cross_validator_params=cross_validator_params, test_io_data=test_io_data)


class ClassificationModelEvaluation(ModelEvaluation[VectorClassificationModel, VectorClassificationModelEvaluator,
        VectorClassificationModelEvaluationData, VectorClassificationModelCrossValidator, VectorClassificationModelCrossValidationData,
        ClassificationEvalStats]):
    def __init__(self, io_data: InputOutputData,
            evaluator_params: Optional[Union[ClassificationEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        super().__init__(io_data, eval_stats_plot_collector=ClassificationEvalStatsPlotCollector(), evaluator_params=evaluator_params,
            cross_validator_params=cross_validator_params, test_io_data=test_io_data)
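
# Usage sketch (illustrative, not part of the original module): using a dedicated held-out test set
# instead of an automatic split; the DataFrames and the classifier are assumed to be provided by the caller.
#
#   train_io_data = InputOutputData(train_inputs_df, train_targets_df)
#   test_io_data = InputOutputData(test_inputs_df, test_targets_df)
#   ev = ClassificationModelEvaluation(train_io_data, test_io_data=test_io_data)
#   eval_data = ev.perform_simple_evaluation(my_classifier, show_plots=True)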


class MultiDataModelEvaluation:
    def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
            meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
            evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None):
        """
        :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
            For evaluation or cross-validation, these datasets will usually be split according to the rules
            specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
            explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data
            sets, the io_data will not be split for evaluation, but the test_io_data will be used instead.
        :param key_name: a name for the key value used in io_data_dict, which will be used as a column name in result data frames
        :param meta_data_dict: a dictionary which maps from a name (same keys as in io_data_dict) to a dictionary, which maps
            from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
        :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation==False)
        :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation==True)
        :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
            Entries with non-None values will be used for evaluation of the models that were trained on the respective io_data_dict.
            If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g.
            if you want to use test data sets for some entries, and splitting of the io_data for others).
            If not None, cross-validation cannot be used when calling ``compare_models``.
        """
        if test_io_data_dict is not None:
            missing_keys = set(io_data_dict).difference(test_io_data_dict)
            if len(missing_keys) > 0:
                raise ValueError(
                    "If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys. "
                    f"However, found missing_keys: {missing_keys}")
        self.io_data_dict = io_data_dict
        self.test_io_data_dict = test_io_data_dict

        self.key_name = key_name
        self.evaluator_params = evaluator_params
        self.cross_validator_params = cross_validator_params
        if meta_data_dict is not None:
            self.meta_df = pd.DataFrame(meta_data_dict.values(), index=meta_data_dict.keys())
        else:
            self.meta_df = None

    def compare_models(self,
            model_factories: Sequence[Callable[[], Union[VectorRegressionModel, VectorClassificationModel]]],
            use_cross_validation=False,
            result_writer: Optional[ResultWriter] = None,
            write_per_dataset_results=False,
            write_csvs=False,
            column_name_for_model_ranking: str = None,
            rank_max=True,
            add_combined_eval_stats=False,
            create_metric_distribution_plots=True,
            create_combined_eval_stats_plots=False,
            distribution_plots_cdf=True,
            distribution_plots_cdf_complementary=False,
            visitors: Optional[Iterable["ModelComparisonVisitor"]] = None) \
            -> Union["RegressionMultiDataModelComparisonData", "ClassificationMultiDataModelComparisonData"]:
        """
        :param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
            in a model with a fixed model name (otherwise results cannot be correctly aggregated)
        :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
            This can only be used if the instance's ``test_io_data_dict`` is None.
        :param result_writer: a writer with which to store results; if None, results are not stored
        :param write_per_dataset_results: whether to use result_writer (if not None) in order to generate detailed results for each
            dataset in a subdirectory named according to the name of the dataset
        :param write_csvs: whether to write metrics tables to CSV files
        :param column_name_for_model_ranking: column name to use for ranking models
        :param rank_max: if True, use max for ranking, else min
        :param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
            all EvalStats objects.
            Note that for classification, this is only possible if all individual experiments use the same set of class labels.
        :param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
            datasets (applies only if result_writer is not None)
        :param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
            experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only
            if result_writer is not None).
            Note that for classification, this is only possible if all individual experiments use the same set of class labels.
        :param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
            create_metric_distribution_plots is True and result_writer is not None.
        :param distribution_plots_cdf_complementary: whether to plot the complementary CDF instead of the regular CDF, provided that
            distribution_plots_cdf is True.
        :param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
            comparison.
        :return: an object containing the full comparison results
        """
        if self.test_io_data_dict and use_cross_validation:
            raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

        all_results_df = pd.DataFrame()
        eval_stats_by_model_name = defaultdict(list)
        results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
        is_regression = None
        plot_collector: Optional[EvalStatsPlotCollector] = None
        model_names = None
        model_name_to_string_repr = None

        for i, (key, inputOutputData) in enumerate(self.io_data_dict.items(), start=1):
            log.info(f"Evaluating models for data set #{i}/{len(self.io_data_dict)}: {self.key_name}={key}")
            models = [f() for f in model_factories]

            current_model_names = [model.get_name() for model in models]
            if model_names is None:
                model_names = current_model_names
            elif model_names != current_model_names:
                log.warning(f"Model factories do not produce fixed names; use model.with_name to name your models. "
                    f"Got {current_model_names}, previously got {model_names}")

            if is_regression is None:
                models_are_regression = [model.is_regression_model() for model in models]
                if all(models_are_regression):
                    is_regression = True
                elif not any(models_are_regression):
                    is_regression = False
                else:
                    raise ValueError("The models have to be either all regression models or all classification models, not a mixture")

            test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
            ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
                cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)

            if plot_collector is None:
                plot_collector = ev.eval_stats_plot_collector

            # compute data frame with results for current data set
            if write_per_dataset_results and result_writer is not None:
                child_result_writer = result_writer.child_for_subdirectory(key)
            else:
                child_result_writer = None
            comparison_data = ev.compare_models(models, use_cross_validation=use_cross_validation, result_writer=child_result_writer,
                visitors=visitors, write_visitor_results=False)
            df = comparison_data.results_df

            # augment data frame
            df[self.key_name] = key
            df["model_name"] = df.index
            df = df.reset_index(drop=True)

            # collect eval stats objects by model name
            for modelName, result in comparison_data.result_by_model_name.items():
                if use_cross_validation:
                    eval_stats = result.cross_validation_data.get_eval_stats_collection().get_global_stats()
                else:
                    eval_stats = result.eval_data.get_eval_stats()
                eval_stats_by_model_name[modelName].append(eval_stats)
                results_by_model_name[modelName].append(result)

            all_results_df = pd.concat((all_results_df, df))

            if model_name_to_string_repr is None:
                model_name_to_string_repr = {model.get_name(): model.pprints() for model in models}

        if self.meta_df is not None:
            all_results_df = all_results_df.join(self.meta_df, on=self.key_name, how="left")

        str_all_results = f"All results:\n{all_results_df.to_string()}"
        log.info(str_all_results)

        # create mean result by model, removing any metrics/columns that produced NaN values
        # (because the mean would be computed without them, skipna parameter unsupported)
        all_results_grouped = all_results_df.drop(columns=self.key_name).dropna(axis=1).groupby("model_name")
        mean_results_df: pd.DataFrame = all_results_grouped.mean()
        for colName in [column_name_for_model_ranking, f"mean[{column_name_for_model_ranking}]"]:
            if colName in mean_results_df:
                mean_results_df.sort_values(colName, inplace=True, ascending=not rank_max)
                break
        str_mean_results = f"Mean results (averaged across {len(self.io_data_dict)} data sets):\n{mean_results_df.to_string()}"
        log.info(str_mean_results)

        def iter_combined_eval_stats_from_all_data_sets():
            for model_name, evalStatsList in eval_stats_by_model_name.items():
                if is_regression:
                    ev_stats = RegressionEvalStatsCollection(evalStatsList).get_global_stats()
                else:
                    ev_stats = ClassificationEvalStatsCollection(evalStatsList).get_global_stats()
                yield model_name, ev_stats

        # create further aggregations
        agg_dfs = []
        for op_name, agg_fn in [("mean", lambda x: x.mean()), ("std", lambda x: x.std()), ("min", lambda x: x.min()),
                ("max", lambda x: x.max())]:
            agg_df = agg_fn(all_results_grouped)
            agg_df.columns = [f"{op_name}[{c}]" for c in agg_df.columns]
            agg_dfs.append(agg_df)
        further_aggs_df = pd.concat(agg_dfs, axis=1)
        further_aggs_df = further_aggs_df.loc[mean_results_df.index]  # apply same sort order (index is model_name)
        column_order = functools.reduce(lambda a, b: a + b, [list(t) for t in zip(*[df.columns for df in agg_dfs])])
        further_aggs_df = further_aggs_df[column_order]
        str_further_aggs = f"Further aggregations:\n{further_aggs_df.to_string()}"
        log.info(str_further_aggs)

        # combined eval stats from all datasets (per model)
        str_combined_eval_stats = ""
        if add_combined_eval_stats:
            rows = []
            for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
                rows.append({"model_name": modelName, **eval_stats.metrics_dict()})
            combined_stats_df = pd.DataFrame(rows)
            combined_stats_df.set_index("model_name", drop=True, inplace=True)
            combined_stats_df = combined_stats_df.loc[mean_results_df.index]  # apply same sort order (index is model_name)
            str_combined_eval_stats = f"Results on combined test data from all data sets:\n{combined_stats_df.to_string()}\n\n"
            log.info(str_combined_eval_stats)

        if result_writer is not None:
            comparison_content = str_mean_results + "\n\n" + str_further_aggs + "\n\n" + str_combined_eval_stats + str_all_results
            comparison_content += "\n\nModels [example instance]:\n\n"
            comparison_content += "\n\n".join(f"{name} = {s}" for name, s in model_name_to_string_repr.items())
            result_writer.write_text_file("model-comparison-results", comparison_content)
            if write_csvs:
                result_writer.write_data_frame_csv_file("all-results", all_results_df)
                result_writer.write_data_frame_csv_file("mean-results", mean_results_df)

        # create plots from combined data for each model
        if create_combined_eval_stats_plots:
            for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
                child_result_writer = result_writer.child_with_added_prefix(modelName + "_") if result_writer is not None else None
                result_collector = EvaluationResultCollector(show_plots=False, result_writer=child_result_writer)
                plot_collector.create_plots(eval_stats, subtitle=modelName, result_collector=result_collector)

        # collect results from visitors (if any)
        result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
        if visitors is not None:
            for visitor in visitors:
                visitor.collect_results(result_collector)

        # create result
        dataset_names = list(self.io_data_dict.keys())
        if is_regression:
            mdmc_data = RegressionMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df, eval_stats_by_model_name,
                results_by_model_name, dataset_names, model_name_to_string_repr)
        else:
            mdmc_data = ClassificationMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df,
                eval_stats_by_model_name, results_by_model_name, dataset_names, model_name_to_string_repr)

        # plot distributions
        if create_metric_distribution_plots and result_writer is not None:
            mdmc_data.create_distribution_plots(result_writer, cdf=distribution_plots_cdf,
                cdf_complementary=distribution_plots_cdf_complementary)

        return mdmc_data
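
# Usage sketch (illustrative, not part of the original module): comparing model factories across
# several datasets; the factory functions and the per-dataset InputOutputData objects are assumptions.
#
#   multi_eval = MultiDataModelEvaluation(
#       {"dataset_a": io_data_a, "dataset_b": io_data_b},
#       evaluator_params=RegressionEvaluatorParams(fractional_split_test_fraction=0.2))
#   comparison = multi_eval.compare_models(
#       [lambda: make_linear_model(), lambda: make_forest_model()],
#       column_name_for_model_ranking="RMSE", rank_max=False)
#   print(comparison.mean_results_df)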


class ModelComparisonData:
    @dataclass
    class Result:
        eval_data: Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData] = None
        cross_validation_data: Union[VectorClassificationModelCrossValidationData, VectorRegressionModelCrossValidationData] = None

        def iter_evaluation_data(self) -> Iterator[Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData]]:
            if self.eval_data is not None:
                yield self.eval_data
            if self.cross_validation_data is not None:
                yield from self.cross_validation_data.eval_data_list

    def __init__(self, results_df: pd.DataFrame, results_by_model_name: Dict[str, Result], evaluator: Optional[VectorModelEvaluator] = None,
            cross_validator: Optional[VectorModelCrossValidator] = None):
        self.results_df = results_df
        self.result_by_model_name = results_by_model_name
        self.evaluator = evaluator
        self.cross_validator = cross_validator

    def get_best_model_name(self, metric_name: str) -> str:
        idx = np.argmax(self.results_df[metric_name])
        return self.results_df.index[idx]

    def get_best_model(self, metric_name: str) -> Union[VectorClassificationModel, VectorRegressionModel, VectorModelBase]:
        result = self.result_by_model_name[self.get_best_model_name(metric_name)]
        if result.eval_data is None:
            raise ValueError("The best model is not well-defined when using cross-validation")
        return result.eval_data.model


class ModelComparisonVisitor(ABC):
    @abstractmethod
    def visit(self, model_name: str, result: ModelComparisonData.Result):
        pass

    @abstractmethod
    def collect_results(self, result_collector: EvaluationResultCollector) -> None:
        """
        Collects results (such as figures) at the end of the model comparison, based on the results collected

        :param result_collector: the collector to which figures are to be added
        """
        pass


class ModelComparisonVisitorAggregatedFeatureImportance(ModelComparisonVisitor):
    """
    During a model comparison, computes aggregated feature importance values for the model with the given name
    """
    def __init__(self, model_name: str, feature_agg_regex: Sequence[str] = (), write_figure=True, write_data_frame_csv=False):
        r"""
        :param model_name: the name of the model for which to compute the aggregated feature importance values
        :param feature_agg_regex: a sequence of regular expressions describing which feature names to sum as one. Each regex must
            contain exactly one group. If a regex matches a feature name, the feature importance will be summed under the key
            of the matched group instead of the full feature name. For example, the regex r"(\w+)_\d+$" will cause "foo_1" and "foo_2"
            to be summed under "foo" and similarly "bar_1" and "bar_2" to be summed under "bar".
        """
        self.model_name = model_name
        self.agg_feature_importance = AggregatedFeatureImportance(feature_agg_reg_ex=feature_agg_regex)
        self.write_figure = write_figure
        self.write_data_frame_csv = write_data_frame_csv

    def visit(self, model_name: str, result: ModelComparisonData.Result):
        if model_name == self.model_name:
            if result.cross_validation_data is not None:
                models = result.cross_validation_data.trained_models
                if models is not None:
                    for model in models:
                        self._collect(model)
                else:
                    raise ValueError("Models were not returned in cross-validation results")
            elif result.eval_data is not None:
                self._collect(result.eval_data.model)

    def _collect(self, model: Union[FeatureImportanceProvider, VectorModelBase]):
        if not isinstance(model, FeatureImportanceProvider):
            raise ValueError(f"Got model which does not inherit from {FeatureImportanceProvider.__qualname__}: {model}")
        self.agg_feature_importance.add(model.get_feature_importance_dict())

    @deprecated("Use get_feature_importance and create the plot using the returned object")
    def plot_feature_importance(self) -> plt.Figure:
        feature_importance_dict = self.agg_feature_importance.get_aggregated_feature_importance().get_feature_importance_dict()
        return plot_feature_importance(feature_importance_dict, subtitle=self.model_name)

    def get_feature_importance(self) -> FeatureImportance:
        return self.agg_feature_importance.get_aggregated_feature_importance()

    def collect_results(self, result_collector: EvaluationResultCollector):
        feature_importance = self.get_feature_importance()
        if self.write_figure:
            result_collector.add_figure(f"{self.model_name}_feature-importance", feature_importance.plot())
        if self.write_data_frame_csv:
            result_collector.add_data_frame_csv_file(f"{self.model_name}_feature-importance", feature_importance.get_data_frame())
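
# Usage sketch (illustrative, not part of the original module): aggregating feature importance for one
# of the compared models; assumes a model named "random_forest" that implements FeatureImportanceProvider.
#
#   visitor = ModelComparisonVisitorAggregatedFeatureImportance("random_forest")
#   ev.compare_models([rf_model, linear_model], use_cross_validation=True, visitors=[visitor],
#       write_visitor_results=True, result_writer=result_writer)
#   fi = visitor.get_feature_importance()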


class MultiDataModelComparisonData(Generic[TEvalStats, TEvalStatsCollection], ABC):
    def __init__(self, all_results_df: pd.DataFrame,
            mean_results_df: pd.DataFrame,
            agg_results_df: pd.DataFrame,
            eval_stats_by_model_name: Dict[str, List[TEvalStats]],
            results_by_model_name: Dict[str, List[ModelComparisonData.Result]],
            dataset_names: List[str],
            model_name_to_string_repr: Dict[str, str]):
        self.all_results_df = all_results_df
        self.mean_results_df = mean_results_df
        self.agg_results_df = agg_results_df
        self.eval_stats_by_model_name = eval_stats_by_model_name
        self.results_by_model_name = results_by_model_name
        self.dataset_names = dataset_names
        self.model_name_to_string_repr = model_name_to_string_repr

    def get_model_names(self) -> List[str]:
        return list(self.eval_stats_by_model_name.keys())

    def get_model_description(self, model_name: str) -> str:
        return self.model_name_to_string_repr[model_name]

    def get_eval_stats_list(self, model_name: str) -> List[TEvalStats]:
        return self.eval_stats_by_model_name[model_name]

    @abstractmethod
    def get_eval_stats_collection(self, model_name: str) -> TEvalStatsCollection:
        pass

    def iter_model_results(self, model_name: str) -> Iterator[Tuple[str, ModelComparisonData.Result]]:
        results = self.results_by_model_name[model_name]
        yield from zip(self.dataset_names, results)

    def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_complementary=False):
        """
        Creates plots of distributions of metrics across datasets for each model as a histogram, and additionally
        any x-y plots (scatter plots & heat maps) for metrics that have associated paired metrics that were also computed

        :param result_writer: the result writer
        :param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
        :param cdf_complementary: whether to plot the complementary CDF instead of the regular CDF, provided that ``cdf`` is True
        """
        for modelName in self.get_model_names():
            eval_stats_collection = self.get_eval_stats_collection(modelName)
            for metricName in eval_stats_collection.get_metric_names():
                # plot distribution
                fig = eval_stats_collection.plot_distribution(metricName, subtitle=modelName, cdf=cdf, cdf_complementary=cdf_complementary)
                result_writer.write_figure(f"{modelName}_dist-{metricName}", fig)
                # scatter plot with paired metrics
                metric: Metric = eval_stats_collection.get_metric_by_name(metricName)
                for paired_metric in metric.get_paired_metrics():
                    if eval_stats_collection.has_metric(paired_metric):
                        fig = eval_stats_collection.plot_scatter(metric.name, paired_metric.name)
                        result_writer.write_figure(f"{modelName}_scatter-{metric.name}-{paired_metric.name}", fig)
                        fig = eval_stats_collection.plot_heat_map(metric.name, paired_metric.name)
                        result_writer.write_figure(f"{modelName}_heatmap-{metric.name}-{paired_metric.name}", fig)


class ClassificationMultiDataModelComparisonData(MultiDataModelComparisonData[ClassificationEvalStats, ClassificationEvalStatsCollection]):
    def get_eval_stats_collection(self, model_name: str):
        return ClassificationEvalStatsCollection(self.get_eval_stats_list(model_name))


class RegressionMultiDataModelComparisonData(MultiDataModelComparisonData[RegressionEvalStats, RegressionEvalStatsCollection]):
    def get_eval_stats_collection(self, model_name: str):
        return RegressionEvalStatsCollection(self.get_eval_stats_list(model_name))