"""
This module contains methods and classes that facilitate evaluation of different types of models. The suggested
workflow for evaluation is to use these higher-level functionalities instead of instantiating
the evaluation classes directly.
"""
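# Illustrative usage sketch (a minimal example, not executed here): `inputs_df`, `outputs_df` and `my_model` are
# hypothetical placeholders; the functions referenced below are defined further down in this module.
#
#     io_data = InputOutputData(inputs_df, outputs_df)
#     ev = create_evaluation_util(io_data, model=my_model)
#     eval_data = ev.perform_simple_evaluation(my_model, show_plots=True, log_results=True)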
import functools
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any, Union, Generic, TypeVar, Optional, Sequence, Callable, Set, Iterable, List, Iterator, Tuple

import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from .crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData, \
    VectorClassificationModelCrossValidationData, \
    VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator, VectorModelCrossValidator, VectorModelCrossValidatorParams
from .eval_stats import RegressionEvalStatsCollection, ClassificationEvalStatsCollection, RegressionEvalStatsPlotErrorDistribution, \
    RegressionEvalStatsPlotHeatmapGroundTruthPredictions, RegressionEvalStatsPlotScatterGroundTruthPredictions, \
    ClassificationEvalStatsPlotConfusionMatrix, ClassificationEvalStatsPlotPrecisionRecall, RegressionEvalStatsPlot, \
    ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall, ClassificationEvalStatsPlotProbabilityThresholdCounts, \
    Metric
from .eval_stats.eval_stats_base import EvalStats, EvalStatsCollection, EvalStatsPlot
from .eval_stats.eval_stats_classification import ClassificationEvalStats
from .eval_stats.eval_stats_regression import RegressionEvalStats
from .evaluator import VectorModelEvaluator, VectorModelEvaluationData, VectorRegressionModelEvaluator, \
    VectorRegressionModelEvaluationData, VectorClassificationModelEvaluator, VectorClassificationModelEvaluationData, \
    RegressionEvaluatorParams, ClassificationEvaluatorParams
from ..data import InputOutputData
from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plot_feature_importance, FeatureImportance
from ..tracking import TrackedExperiment
from ..tracking.tracking_base import TrackingContext
from ..util.deprecation import deprecated
from ..util.io import ResultWriter
from ..util.string import pretty_string_repr
from ..vector_model import VectorClassificationModel, VectorRegressionModel, VectorModel, VectorModelBase

log = logging.getLogger(__name__)
TModel = TypeVar("TModel", bound=VectorModel)
TEvalStats = TypeVar("TEvalStats", bound=EvalStats)
TEvalStatsPlot = TypeVar("TEvalStatsPlot", bound=EvalStatsPlot)
TEvalStatsCollection = TypeVar("TEvalStatsCollection", bound=EvalStatsCollection)
TEvaluator = TypeVar("TEvaluator", bound=VectorModelEvaluator)
TCrossValidator = TypeVar("TCrossValidator", bound=VectorModelCrossValidator)
TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData)
TCrossValData = TypeVar("TCrossValData", bound=VectorModelCrossValidationData)
def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool]) -> bool:
    if (model is None and is_regression is None) or (model is not None and is_regression is not None):
        raise ValueError("Exactly one of the two parameters has to be passed: model or is_regression")

    if is_regression is None:
        model: VectorModel
        return model.is_regression_model()
    return is_regression
def create_vector_model_evaluator(data: InputOutputData, model: VectorModel = None,
        is_regression: bool = None, params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None,
        test_data: Optional[InputOutputData] = None) \
        -> Union[VectorRegressionModelEvaluator, VectorClassificationModelEvaluator]:
    is_regression = _is_regression(model, is_regression)
    if params is None:
        if is_regression:
            params = RegressionEvaluatorParams(fractional_split_test_fraction=0.2)
        else:
            params = ClassificationEvaluatorParams(fractional_split_test_fraction=0.2)
        log.debug(f"No evaluator parameters specified, using default: {params}")
    if is_regression:
        return VectorRegressionModelEvaluator(data, test_data=test_data, params=params)
    else:
        return VectorClassificationModelEvaluator(data, test_data=test_data, params=params)
def create_vector_model_cross_validator(data: InputOutputData,
        model: VectorModel = None,
        is_regression: bool = None,
        params: Union[VectorModelCrossValidatorParams, Dict[str, Any]] = None) \
        -> Union[VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator]:
    if params is None:
        raise ValueError("params must not be None")
    cons = VectorRegressionModelCrossValidator if _is_regression(model, is_regression) else VectorClassificationModelCrossValidator
    return cons(data, params=params)
def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
        evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
        cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
        -> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
    if _is_regression(model, is_regression):
        return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params,
            test_io_data=test_io_data)
    else:
        return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params,
            test_io_data=test_io_data)
def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
        plot_target_distribution=False, compute_probabilities=True, normalize_plots=True, random_seed=60) -> TEvalData:
    """
    Evaluates the given model via a simple evaluation mechanism that uses a single split into training and test data

    :param model: the model to evaluate
    :param io_data: the data on which to evaluate the model
    :param test_fraction: the fraction of the data to test on
    :param plot_target_distribution: whether to plot the distribution of target values in the entire dataset
    :param compute_probabilities: whether to compute class probabilities; only relevant if the model is a classifier
    :param normalize_plots: whether to normalize plotted distributions such that they sum/integrate to 1
    :param random_seed: the random seed to use for the fractional split of the data
    :return: the evaluation data
    """
    if plot_target_distribution:
        title = "Distribution of target values in entire dataset"
        fig = plt.figure(title)

        output_distribution_series = io_data.outputs.iloc[:, 0]
        log.info(f"Description of target column in training set: \n{output_distribution_series.describe()}")
        if not model.is_regression_model():
            output_distribution_series = output_distribution_series.value_counts(normalize=normalize_plots)
            # pass x and y as keyword arguments (positional data arguments are no longer supported by recent seaborn versions)
            ax = sns.barplot(x=output_distribution_series.index, y=output_distribution_series.values)
            ax.set_ylabel("%")
        else:
            # histplot with a KDE replaces the deprecated distplot
            ax = sns.histplot(output_distribution_series, kde=True, stat="density")
            ax.set_ylabel("Probability density")
        ax.set_title(title)
        ax.set_xlabel("target value")
        fig.show()

    if model.is_regression_model():
        evaluator_params = RegressionEvaluatorParams(fractional_split_test_fraction=test_fraction,
            fractional_split_random_seed=random_seed)
    else:
        evaluator_params = ClassificationEvaluatorParams(fractional_split_test_fraction=test_fraction,
            compute_probabilities=compute_probabilities, fractional_split_random_seed=random_seed)
    ev = create_evaluation_util(io_data, model=model, evaluator_params=evaluator_params)
    return ev.perform_simple_evaluation(model, show_plots=True, log_results=True)
class EvaluationResultCollector:
    def __init__(self, show_plots: bool = True, result_writer: Optional[ResultWriter] = None,
            tracking_context: TrackingContext = None):
        self.show_plots = show_plots
        self.result_writer = result_writer
        self.tracking_context = tracking_context

    def is_plot_creation_enabled(self) -> bool:
        return self.show_plots or self.result_writer is not None or self.tracking_context is not None

    def add_figure(self, name: str, fig: matplotlib.figure.Figure):
        if self.result_writer is not None:
            self.result_writer.write_figure(name, fig, close_figure=False)
        if self.tracking_context is not None:
            self.tracking_context.track_figure(name, fig)
        if not self.show_plots:
            plt.close(fig)

    def add_data_frame_csv_file(self, name: str, df: pd.DataFrame):
        if self.result_writer is not None:
            self.result_writer.write_data_frame_csv_file(name, df)

    def child(self, added_filename_prefix):
        result_writer = self.result_writer
        if result_writer:
            result_writer = result_writer.child_with_added_prefix(added_filename_prefix)
        return self.__class__(show_plots=self.show_plots, result_writer=result_writer)
class EvalStatsPlotCollector(Generic[TEvalStats, TEvalStatsPlot]):
    def __init__(self):
        self.plots: Dict[str, EvalStatsPlot] = {}
        self.disabled_plots: Set[str] = set()

    def add_plot(self, name: str, plot: EvalStatsPlot):
        self.plots[name] = plot

    def get_enabled_plots(self) -> List[str]:
        return [p for p in self.plots if p not in self.disabled_plots]

    def disable_plots(self, *names: str):
        self.disabled_plots.update(names)

    def create_plots(self, eval_stats: EvalStats, subtitle: str, result_collector: EvaluationResultCollector):
        known_plots = set(self.plots.keys())
        unknown_disabled_plots = self.disabled_plots.difference(known_plots)
        if len(unknown_disabled_plots) > 0:
            log.warning(f"Plots were disabled which are not registered: {unknown_disabled_plots}; known plots: {known_plots}")
        for name, plot in self.plots.items():
            if name not in self.disabled_plots and plot.is_applicable(eval_stats):
                fig = plot.create_figure(eval_stats, subtitle)
                if fig is not None:
                    result_collector.add_figure(name, fig)
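# Illustrative sketch of how plot generation can be customised: the ModelEvaluation classes defined further below hold a
# plot collector instance, whose registered plots can be disabled by name (here "error-dist", registered just below;
# `io_data`, `model` and `result_writer` are hypothetical objects):
#
#     ev = create_evaluation_util(io_data, is_regression=True)
#     ev.eval_stats_plot_collector.disable_plots("error-dist")
#     ev.perform_simple_evaluation(model, create_plots=True, show_plots=False, result_writer=result_writer)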
class RegressionEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]):
    def __init__(self):
        super().__init__()
        self.add_plot("error-dist", RegressionEvalStatsPlotErrorDistribution())
        self.add_plot("heatmap-gt-pred", RegressionEvalStatsPlotHeatmapGroundTruthPredictions(weighted=False))
        self.add_plot("heatmap-gt-pred-weighted", RegressionEvalStatsPlotHeatmapGroundTruthPredictions(weighted=True))
        self.add_plot("scatter-gt-pred", RegressionEvalStatsPlotScatterGroundTruthPredictions())
class ClassificationEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]):
    def __init__(self):
        super().__init__()
        self.add_plot("confusion-matrix-rel", ClassificationEvalStatsPlotConfusionMatrix(normalise=True))
        self.add_plot("confusion-matrix-abs", ClassificationEvalStatsPlotConfusionMatrix(normalise=False))
        # the plots below apply to the binary case only (skipped for non-binary case)
        self.add_plot("precision-recall", ClassificationEvalStatsPlotPrecisionRecall())
        self.add_plot("threshold-precision-recall", ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall())
        self.add_plot("threshold-counts", ClassificationEvalStatsPlotProbabilityThresholdCounts())
class ModelEvaluation(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]):
    """
    Utility class for the evaluation of models based on a dataset
    """
    def __init__(self, io_data: InputOutputData,
            eval_stats_plot_collector: Union[RegressionEvalStatsPlotCollector, ClassificationEvalStatsPlotCollector],
            evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams,
                Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param eval_stats_plot_collector: a collector for plots generated from evaluation stats objects
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        if cross_validator_params is None:
            cross_validator_params = VectorModelCrossValidatorParams(folds=5)
        self.evaluator_params = evaluator_params
        self.cross_validator_params = cross_validator_params
        self.io_data = io_data
        self.test_io_data = test_io_data
        self.eval_stats_plot_collector = eval_stats_plot_collector
    def create_evaluator(self, model: TModel = None, is_regression: bool = None) -> TEvaluator:
        """
        Creates an evaluator holding the current input-output data

        :param model: the model for which to create an evaluator (used only for reading off regression or classification;
            the resulting evaluator will work on other models as well)
        :param is_regression: whether to create a regression model evaluator. Either this or `model` has to be specified
        :return: an evaluator
        """
        return create_vector_model_evaluator(self.io_data, model=model, is_regression=is_regression, test_data=self.test_io_data,
            params=self.evaluator_params)

    def create_cross_validator(self, model: TModel = None, is_regression: bool = None) -> TCrossValidator:
        """
        Creates a cross-validator holding the current input-output data

        :param model: the model for which to create a cross-validator (used only for reading off regression or classification;
            the resulting cross-validator will work on other models as well)
        :param is_regression: whether to create a regression model cross-validator. Either this or `model` has to be specified
        :return: a cross-validator
        """
        return create_vector_model_cross_validator(self.io_data, model=model, is_regression=is_regression,
            params=self.cross_validator_params)
    def perform_simple_evaluation(self, model: TModel,
            create_plots=True, show_plots=False,
            log_results=True,
            result_writer: ResultWriter = None,
            additional_evaluation_on_training_data=False,
            fit_model=True, write_eval_stats=False,
            tracked_experiment: TrackedExperiment = None,
            evaluator: Optional[TEvaluator] = None) -> TEvalData:

        if show_plots and not create_plots:
            raise ValueError("show_plots=True requires create_plots=True")
        result_writer = self._result_writer_for_model(result_writer, model)
        if evaluator is None:
            evaluator = self.create_evaluator(model)
        if tracked_experiment is not None:
            evaluator.set_tracked_experiment(tracked_experiment)
        log.info(f"Evaluating {model} via {evaluator}")

        def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_prefix=""):
            str_eval_results = ""
            for predictedVarName in result_data.predicted_var_names:
                eval_stats = result_data.get_eval_stats(predictedVarName)
                str_eval_result = str(eval_stats)
                if log_results:
                    log.info(f"{subtitle_prefix}Evaluation results for {predictedVarName}: {str_eval_result}")
                str_eval_results += predictedVarName + ": " + str_eval_result + "\n"
                if write_eval_stats and res_writer is not None:
                    res_writer.write_pickle(f"eval-stats-{predictedVarName}", eval_stats)
            str_eval_results += f"\n\n{pretty_string_repr(model)}"
            if res_writer is not None:
                res_writer.write_text_file("evaluator-results", str_eval_results)
            if create_plots:
                with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
                    self.create_plots(result_data, show_plots=show_plots, result_writer=res_writer,
                        subtitle_prefix=subtitle_prefix, tracking_context=trackingContext)

        eval_result_data = evaluator.eval_model(model, fit=fit_model)
        gather_results(eval_result_data, result_writer)
        if additional_evaluation_on_training_data:
            eval_result_data_train = evaluator.eval_model(model, on_training_data=True, track=False)
            additional_result_writer = result_writer.child_with_added_prefix("onTrain-") if result_writer is not None else None
            gather_results(eval_result_data_train, additional_result_writer, subtitle_prefix="[onTrain] ")
        return eval_result_data

    @staticmethod
    def _result_writer_for_model(result_writer: Optional[ResultWriter], model: TModel) -> Optional[ResultWriter]:
        if result_writer is None:
            return None
        return result_writer.child_with_added_prefix(model.get_name() + "_")
    def perform_cross_validation(self, model: TModel, show_plots=False, log_results=True, result_writer: Optional[ResultWriter] = None,
            tracked_experiment: TrackedExperiment = None, cross_validator: Optional[TCrossValidator] = None) -> TCrossValData:
        """
        Evaluates the given model via cross-validation

        :param model: the model to evaluate
        :param show_plots: whether to show plots that visualise evaluation results (combining all folds)
        :param log_results: whether to log evaluation results
        :param result_writer: a writer with which to store text files and plots. The evaluated model's name is added to each filename
            automatically
        :param tracked_experiment: a tracked experiment with which results shall be associated
        :param cross_validator: the cross-validator to apply; if None, a suitable cross-validator will be created
        :return: cross-validation result data
        """
        result_writer = self._result_writer_for_model(result_writer, model)

        if cross_validator is None:
            cross_validator = self.create_cross_validator(model)
        if tracked_experiment is not None:
            cross_validator.set_tracked_experiment(tracked_experiment)

        cross_validation_data = cross_validator.eval_model(model)

        agg_stats_by_var = {varName: cross_validation_data.get_eval_stats_collection(predicted_var_name=varName).agg_metrics_dict()
            for varName in cross_validation_data.predicted_var_names}
        df = pd.DataFrame.from_dict(agg_stats_by_var, orient="index")

        str_eval_results = df.to_string()
        if log_results:
            log.info(f"Cross-validation results:\n{str_eval_results}")
        if result_writer is not None:
            result_writer.write_text_file("crossval-results", str_eval_results)

        with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
            self.create_plots(cross_validation_data, show_plots=show_plots, result_writer=result_writer,
                tracking_context=trackingContext)

        return cross_validation_data
    def compare_models(self, models: Sequence[TModel], result_writer: Optional[ResultWriter] = None, use_cross_validation=False,
            fit_models=True, write_individual_results=True, sort_column: Optional[str] = None, sort_ascending: bool = True,
            sort_column_move_to_left=True,
            also_include_unsorted_results: bool = False, also_include_cross_val_global_stats: bool = False,
            visitors: Optional[Iterable["ModelComparisonVisitor"]] = None,
            write_visitor_results=False, write_csv=False,
            tracked_experiment: Optional[TrackedExperiment] = None) -> "ModelComparisonData":
        """
        Compares several models via simple evaluation or cross-validation

        :param models: the models to compare
        :param result_writer: a writer with which to store results of the comparison
        :param use_cross_validation: whether to use cross-validation in order to evaluate models; if False, use a simple evaluation
            on test data (single split)
        :param fit_models: whether to fit models before evaluating them; this can only be False if use_cross_validation=False
        :param write_individual_results: whether to write results files on each individual model (in addition to the comparison
            summary)
        :param sort_column: column/metric name by which to sort; the fact that the column names change when using cross-validation
            (aggregation function names being added) should be ignored, simply pass the (unmodified) metric name
        :param sort_ascending: whether to sort using `sort_column` in ascending order
        :param sort_column_move_to_left: whether to move the `sort_column` (if any) to the very left
        :param also_include_unsorted_results: whether to also include, for the case where the results are sorted, the unsorted table of
            results in the results text
        :param also_include_cross_val_global_stats: whether to also include, when using cross-validation, the evaluation metrics obtained
            when combining the predictions from all folds into a single collection. Note that for classification models,
            this may not always be possible (if the set of classes known to the model differs across folds)
        :param visitors: visitors which may process individual results
        :param write_visitor_results: whether to collect results from visitors (if any) after the comparison
        :param write_csv: whether to write the metrics table to CSV files
        :param tracked_experiment: an experiment for tracking
        :return: the comparison results
        """
        # collect model evaluation results
        stats_list = []
        result_by_model_name = {}
        evaluator = None
        cross_validator = None
        for i, model in enumerate(models, start=1):
            model_name = model.get_name()
            log.info(f"Evaluating model {i}/{len(models)} named '{model_name}' ...")
            if use_cross_validation:
                if not fit_models:
                    raise ValueError("Cross-validation necessitates that models be trained several times; got fit_models=False")
                if cross_validator is None:
                    cross_validator = self.create_cross_validator(model)
                cross_val_data = self.perform_cross_validation(model, result_writer=result_writer if write_individual_results else None,
                    cross_validator=cross_validator, tracked_experiment=tracked_experiment)
                model_result = ModelComparisonData.Result(cross_validation_data=cross_val_data)
                result_by_model_name[model_name] = model_result
                eval_stats_collection = cross_val_data.get_eval_stats_collection()
                stats_dict = eval_stats_collection.agg_metrics_dict()
            else:
                if evaluator is None:
                    evaluator = self.create_evaluator(model)
                eval_data = self.perform_simple_evaluation(model, result_writer=result_writer if write_individual_results else None,
                    fit_model=fit_models, evaluator=evaluator, tracked_experiment=tracked_experiment)
                model_result = ModelComparisonData.Result(eval_data=eval_data)
                result_by_model_name[model_name] = model_result
                eval_stats = eval_data.get_eval_stats()
                stats_dict = eval_stats.metrics_dict()
            stats_dict["model_name"] = model_name
            stats_list.append(stats_dict)
            if visitors is not None:
                for visitor in visitors:
                    visitor.visit(model_name, model_result)
        results_df = pd.DataFrame(stats_list).set_index("model_name")
        # compute results data frame with combined set of data points (for cross-validation only)
        cross_val_combined_results_df = None
        if use_cross_validation and also_include_cross_val_global_stats:
            try:
                rows = []
                for model_name, result in result_by_model_name.items():
                    stats_dict = result.cross_validation_data.get_eval_stats_collection().get_global_stats().metrics_dict()
                    stats_dict["model_name"] = model_name
                    rows.append(stats_dict)
                cross_val_combined_results_df = pd.DataFrame(rows).set_index("model_name")
            except Exception as e:
                log.error(f"Creation of global stats data frame from cross-validation folds failed: {e}")

        def sorted_df(df, sort_col):
            if sort_col is not None:
                if sort_col not in df.columns:
                    alt_sort_col = f"mean[{sort_col}]"
                    if alt_sort_col in df.columns:
                        sort_col = alt_sort_col
                    else:
                        # log the warning before discarding the requested sort column
                        log.warning(f"Requested sort column '{sort_col}' (or '{alt_sort_col}') not in list of columns {list(df.columns)}")
                        sort_col = None
                if sort_col is not None:
                    df = df.sort_values(sort_col, ascending=sort_ascending, inplace=False)
                    if sort_column_move_to_left:
                        df = df[[sort_col] + [c for c in df.columns if c != sort_col]]
            return df
        # write comparison results
        title = "Model comparison results"
        if use_cross_validation:
            title += ", aggregated across folds"
        sorted_results_df = sorted_df(results_df, sort_column)
        str_results = f"{title}:\n{sorted_results_df.to_string()}"
        if also_include_unsorted_results and sort_column is not None:
            str_results += f"\n\n{title} (unsorted):\n{results_df.to_string()}"
        sorted_cross_val_combined_results_df = None
        if cross_val_combined_results_df is not None:
            sorted_cross_val_combined_results_df = sorted_df(cross_val_combined_results_df, sort_column)
            str_results += f"\n\nModel comparison results based on combined set of data points from all folds:\n" \
                f"{sorted_cross_val_combined_results_df.to_string()}"
        log.info(str_results)
        if result_writer is not None:
            suffix = "crossval" if use_cross_validation else "simple-eval"
            str_results += "\n\n" + "\n\n".join([f"{model.get_name()} = {model.pprints()}" for model in models])
            result_writer.write_text_file(f"model-comparison-results-{suffix}", str_results)
            if write_csv:
                result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}", sorted_results_df)
                if sorted_cross_val_combined_results_df is not None:
                    result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}-combined",
                        sorted_cross_val_combined_results_df)

        # write visitor results
        if visitors is not None and write_visitor_results:
            result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
            for visitor in visitors:
                visitor.collect_results(result_collector)

        return ModelComparisonData(results_df, result_by_model_name, evaluator=evaluator, cross_validator=cross_validator)
    def compare_models_cross_validation(self, models: Sequence[TModel],
            result_writer: Optional[ResultWriter] = None) -> "ModelComparisonData":
        """
        Compares several models via cross-validation

        :param models: the models to compare
        :param result_writer: a writer with which to store results of the comparison
        :return: the comparison results
        """
        return self.compare_models(models, result_writer=result_writer, use_cross_validation=True)

    def create_plots(self, data: Union[TEvalData, TCrossValData], show_plots=True, result_writer: Optional[ResultWriter] = None,
            subtitle_prefix: str = "", tracking_context: Optional[TrackingContext] = None):
        """
        Creates default plots that visualise the results in the given evaluation data

        :param data: the evaluation data for which to create the default plots
        :param show_plots: whether to show plots
        :param result_writer: if not None, plots will be written using this writer
        :param subtitle_prefix: a prefix to add to the subtitle (which itself is the model name)
        :param tracking_context: the experiment tracking context
        """
        result_collector = EvaluationResultCollector(show_plots=show_plots, result_writer=result_writer,
            tracking_context=tracking_context)
        if result_collector.is_plot_creation_enabled():
            self._create_plots(data, result_collector, subtitle=subtitle_prefix + data.model_name)

    def _create_plots(self, data: Union[TEvalData, TCrossValData], result_collector: EvaluationResultCollector, subtitle=None):

        def create_plots(pred_var_name, res_collector, subt):
            if isinstance(data, VectorModelCrossValidationData):
                eval_stats = data.get_eval_stats_collection(predicted_var_name=pred_var_name).get_global_stats()
            elif isinstance(data, VectorModelEvaluationData):
                eval_stats = data.get_eval_stats(predicted_var_name=pred_var_name)
            else:
                raise ValueError(f"Unexpected argument: data={data}")
            return self._create_eval_stats_plots(eval_stats, res_collector, subtitle=subt)

        predicted_var_names = data.predicted_var_names
        if len(predicted_var_names) == 1:
            create_plots(predicted_var_names[0], result_collector, subtitle)
        else:
            for predictedVarName in predicted_var_names:
                create_plots(predictedVarName, result_collector.child(predictedVarName + "-"), f"{predictedVarName}, {subtitle}")

    def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: EvaluationResultCollector, subtitle=None):
        """
        :param eval_stats: the evaluation results for which to create plots
        :param result_collector: the collector to which all plots are to be passed
        :param subtitle: the subtitle to use for generated plots (if any)
        """
        self.eval_stats_plot_collector.create_plots(eval_stats, subtitle, result_collector)
class RegressionModelEvaluation(ModelEvaluation[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData,
        VectorRegressionModelCrossValidator, VectorRegressionModelCrossValidationData, RegressionEvalStats]):
    def __init__(self, io_data: InputOutputData,
            evaluator_params: Optional[Union[RegressionEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        super().__init__(io_data, eval_stats_plot_collector=RegressionEvalStatsPlotCollector(), evaluator_params=evaluator_params,
            cross_validator_params=cross_validator_params, test_io_data=test_io_data)


class ClassificationModelEvaluation(ModelEvaluation[VectorClassificationModel, VectorClassificationModelEvaluator,
        VectorClassificationModelEvaluationData, VectorClassificationModelCrossValidator, VectorClassificationModelCrossValidationData,
        ClassificationEvalStats]):
    def __init__(self, io_data: InputOutputData,
            evaluator_params: Optional[Union[ClassificationEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data: Optional[InputOutputData] = None):
        """
        :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
            into training and test data according to the rules specified by `evaluator_params`.
            However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
            taken to be the test data when creating evaluators for simple (single-split) evaluation.
        :param evaluator_params: parameters with which to instantiate evaluators
        :param cross_validator_params: parameters with which to instantiate cross-validators
        :param test_io_data: optional test data (see `io_data`)
        """
        super().__init__(io_data, eval_stats_plot_collector=ClassificationEvalStatsPlotCollector(), evaluator_params=evaluator_params,
            cross_validator_params=cross_validator_params, test_io_data=test_io_data)
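# Illustrative usage sketch for comparing models with the classes above; `inputs_df`, `outputs_df`, `model_a`, `model_b`
# and `result_writer` are hypothetical objects, and "RMSE" is assumed to be one of the computed metric names:
#
#     ev = RegressionModelEvaluation(InputOutputData(inputs_df, outputs_df))
#     comparison = ev.compare_models([model_a, model_b], result_writer=result_writer,
#         use_cross_validation=True, sort_column="RMSE")
#     print(comparison.results_df)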
class MultiDataModelEvaluation:
    def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
            meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
            evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
            cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
            test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None):
        """
        :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
            For evaluation or cross-validation, these datasets will usually be split according to the rules
            specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
            explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data
            sets, the io_data will not be split for evaluation, but the test_io_data will be used instead.
        :param key_name: a name for the key value used in `io_data_dict`, which will be used as a column name in result data frames
        :param meta_data_dict: a dictionary which maps from a name (same keys as in `io_data_dict`) to a dictionary, which maps
            from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
        :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation=False)
        :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation=True)
        :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
            Entries with non-None values will be used for evaluation of the models that were trained on the respective io_data_dict.
            If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g.
            if you want to use test data sets for some entries, and splitting of the io_data for others).
            If not None, cross-validation cannot be used when calling ``compare_models``.
        """
        if test_io_data_dict is not None:
            missing_keys = set(io_data_dict).difference(test_io_data_dict)
            if len(missing_keys) > 0:
                raise ValueError(
                    "If test_io_data_dict is passed, its keys must be a superset of io_data_dict's keys. "
                    f"However, the following keys are missing: {missing_keys}")
        self.io_data_dict = io_data_dict
        self.test_io_data_dict = test_io_data_dict

        self.key_name = key_name
        self.evaluator_params = evaluator_params
        self.cross_validator_params = cross_validator_params
        if meta_data_dict is not None:
            self.meta_df = pd.DataFrame(meta_data_dict.values(), index=meta_data_dict.keys())
        else:
            self.meta_df = None
    def compare_models(self,
            model_factories: Sequence[Callable[[], Union[VectorRegressionModel, VectorClassificationModel]]],
            use_cross_validation=False,
            result_writer: Optional[ResultWriter] = None,
            write_per_dataset_results=False,
            write_csvs=False,
            column_name_for_model_ranking: str = None,
            rank_max=True,
            add_combined_eval_stats=False,
            create_metric_distribution_plots=True,
            create_combined_eval_stats_plots=False,
            distribution_plots_cdf=True,
            distribution_plots_cdf_complementary=False,
            visitors: Optional[Iterable["ModelComparisonVisitor"]] = None) \
            -> Union["RegressionMultiDataModelComparisonData", "ClassificationMultiDataModelComparisonData"]:
        """
        :param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
            in a model with a fixed model name (otherwise results cannot be correctly aggregated)
        :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
            This can only be used if the instance's ``test_io_data_dict`` is None.
        :param result_writer: a writer with which to store results; if None, results are not stored
        :param write_per_dataset_results: whether to use `result_writer` (if not None) in order to generate detailed results for each
            dataset in a subdirectory named according to the name of the dataset
        :param write_csvs: whether to write metrics tables to CSV files
        :param column_name_for_model_ranking: column name to use for ranking models
        :param rank_max: if True, use the maximum for ranking, else the minimum
        :param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
            all EvalStats objects.
            Note that for classification, this is only possible if all individual experiments use the same set of class labels.
        :param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
            datasets (applies only if `result_writer` is not None)
        :param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
            experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only
            if `result_writer` is not None).
            Note that for classification, this is only possible if all individual experiments use the same set of class labels.
        :param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
            `create_metric_distribution_plots` is True and `result_writer` is not None.
        :param distribution_plots_cdf_complementary: whether to plot the complementary CDF instead of the regular CDF, provided that
            `distribution_plots_cdf` is True.
        :param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
            comparison.
        :return: an object containing the full comparison results
        """
        if self.test_io_data_dict and use_cross_validation:
            raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

        all_results_df = pd.DataFrame()
        eval_stats_by_model_name = defaultdict(list)
        results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
        is_regression = None
        plot_collector: Optional[EvalStatsPlotCollector] = None
        model_names = None
        model_name_to_string_repr = None

        for i, (key, inputOutputData) in enumerate(self.io_data_dict.items(), start=1):
            log.info(f"Evaluating models for data set #{i}/{len(self.io_data_dict)}: {self.key_name}={key}")
            models = [f() for f in model_factories]

            current_model_names = [model.get_name() for model in models]
            if model_names is None:
                model_names = current_model_names
            elif model_names != current_model_names:
                log.warning(f"Model factories do not produce fixed names; use with_name to name your models. "
                    f"Got {current_model_names}, previously got {model_names}")

            if is_regression is None:
                models_are_regression = [model.is_regression_model() for model in models]
                if all(models_are_regression):
                    is_regression = True
                elif not any(models_are_regression):
                    is_regression = False
                else:
                    raise ValueError("The models have to be either all regression models or all classification models, not a mixture")

            test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
            ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
                cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)
            if plot_collector is None:
                plot_collector = ev.eval_stats_plot_collector

            # compute data frame with results for current data set
            if write_per_dataset_results and result_writer is not None:
                child_result_writer = result_writer.child_for_subdirectory(key)
            else:
                child_result_writer = None
            comparison_data = ev.compare_models(models, use_cross_validation=use_cross_validation, result_writer=child_result_writer,
                visitors=visitors, write_visitor_results=False)
            df = comparison_data.results_df

            # augment data frame
            df[self.key_name] = key
            df["model_name"] = df.index
            df = df.reset_index(drop=True)

            # collect eval stats objects by model name
            for modelName, result in comparison_data.result_by_model_name.items():
                if use_cross_validation:
                    eval_stats = result.cross_validation_data.get_eval_stats_collection().get_global_stats()
                else:
                    eval_stats = result.eval_data.get_eval_stats()
                eval_stats_by_model_name[modelName].append(eval_stats)
                results_by_model_name[modelName].append(result)

            all_results_df = pd.concat((all_results_df, df))

            if model_name_to_string_repr is None:
                model_name_to_string_repr = {model.get_name(): model.pprints() for model in models}

        if self.meta_df is not None:
            all_results_df = all_results_df.join(self.meta_df, on=self.key_name, how="left")
        str_all_results = f"All results:\n{all_results_df.to_string()}"
        log.info(str_all_results)

        # create mean result by model, removing any metrics/columns that produced NaN values
        # (because the mean would otherwise be computed without them; the skipna parameter is unsupported)
        all_results_grouped = all_results_df.drop(columns=self.key_name).dropna(axis=1).groupby("model_name")
        mean_results_df: pd.DataFrame = all_results_grouped.mean()
        for colName in [column_name_for_model_ranking, f"mean[{column_name_for_model_ranking}]"]:
            if colName in mean_results_df:
                # sort by the column that was actually found (which may be the mean-aggregated variant)
                mean_results_df.sort_values(colName, inplace=True, ascending=not rank_max)
                break
        str_mean_results = f"Mean results (averaged across {len(self.io_data_dict)} data sets):\n{mean_results_df.to_string()}"
        log.info(str_mean_results)
        def iter_combined_eval_stats_from_all_data_sets():
            for model_name, evalStatsList in eval_stats_by_model_name.items():
                if is_regression:
                    ev_stats = RegressionEvalStatsCollection(evalStatsList).get_global_stats()
                else:
                    ev_stats = ClassificationEvalStatsCollection(evalStatsList).get_global_stats()
                yield model_name, ev_stats

        # create further aggregations
        agg_dfs = []
        for op_name, agg_fn in [("mean", lambda x: x.mean()), ("std", lambda x: x.std()), ("min", lambda x: x.min()),
                ("max", lambda x: x.max())]:
            agg_df = agg_fn(all_results_grouped)
            agg_df.columns = [f"{op_name}[{c}]" for c in agg_df.columns]
            agg_dfs.append(agg_df)
        further_aggs_df = pd.concat(agg_dfs, axis=1)
        further_aggs_df = further_aggs_df.loc[mean_results_df.index]  # apply same sort order (index is model_name)
        column_order = functools.reduce(lambda a, b: a + b, [list(t) for t in zip(*[df.columns for df in agg_dfs])])
        further_aggs_df = further_aggs_df[column_order]
        str_further_aggs = f"Further aggregations:\n{further_aggs_df.to_string()}"
        log.info(str_further_aggs)
        # combined eval stats from all datasets (per model)
        str_combined_eval_stats = ""
        if add_combined_eval_stats:
            rows = []
            for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
                rows.append({"model_name": modelName, **eval_stats.metrics_dict()})
            combined_stats_df = pd.DataFrame(rows)
            combined_stats_df.set_index("model_name", drop=True, inplace=True)
            combined_stats_df = combined_stats_df.loc[mean_results_df.index]  # apply same sort order (index is model_name)
            str_combined_eval_stats = f"Results on combined test data from all data sets:\n{combined_stats_df.to_string()}\n\n"
            log.info(str_combined_eval_stats)

        if result_writer is not None:
            comparison_content = str_mean_results + "\n\n" + str_further_aggs + "\n\n" + str_combined_eval_stats + str_all_results
            comparison_content += "\n\nModels [example instance]:\n\n"
            comparison_content += "\n\n".join(f"{name} = {s}" for name, s in model_name_to_string_repr.items())
            result_writer.write_text_file("model-comparison-results", comparison_content)
            if write_csvs:
                result_writer.write_data_frame_csv_file("all-results", all_results_df)
                result_writer.write_data_frame_csv_file("mean-results", mean_results_df)

        # create plots from combined data for each model
        if create_combined_eval_stats_plots:
            for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
                child_result_writer = result_writer.child_with_added_prefix(modelName + "_") if result_writer is not None else None
                result_collector = EvaluationResultCollector(show_plots=False, result_writer=child_result_writer)
                plot_collector.create_plots(eval_stats, subtitle=modelName, result_collector=result_collector)

        # collect results from visitors (if any)
        result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
        if visitors is not None:
            for visitor in visitors:
                visitor.collect_results(result_collector)

        # create result
        dataset_names = list(self.io_data_dict.keys())
        if is_regression:
            mdmc_data = RegressionMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df, eval_stats_by_model_name,
                results_by_model_name, dataset_names, model_name_to_string_repr)
        else:
            mdmc_data = ClassificationMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df,
                eval_stats_by_model_name, results_by_model_name, dataset_names, model_name_to_string_repr)

        # plot distributions
        if create_metric_distribution_plots and result_writer is not None:
            mdmc_data.create_distribution_plots(result_writer, cdf=distribution_plots_cdf,
                cdf_complementary=distribution_plots_cdf_complementary)

        return mdmc_data
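# Illustrative usage sketch for a multi-dataset comparison; the InputOutputData instances and model factory
# functions below are hypothetical:
#
#     multi_eval = MultiDataModelEvaluation({"dataset_a": io_data_a, "dataset_b": io_data_b})
#     results = multi_eval.compare_models([lambda: create_model_a(), lambda: create_model_b()],
#         use_cross_validation=True)
#     print(results.mean_results_df)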
class ModelComparisonData:
    @dataclass
    class Result:
        eval_data: Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData] = None
        cross_validation_data: Union[VectorClassificationModelCrossValidationData, VectorRegressionModelCrossValidationData] = None

        def iter_evaluation_data(self) -> Iterator[Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData]]:
            if self.eval_data is not None:
                yield self.eval_data
            if self.cross_validation_data is not None:
                yield from self.cross_validation_data.eval_data_list

    def __init__(self, results_df: pd.DataFrame, results_by_model_name: Dict[str, Result], evaluator: Optional[VectorModelEvaluator] = None,
            cross_validator: Optional[VectorModelCrossValidator] = None):
        self.results_df = results_df
        self.result_by_model_name = results_by_model_name
        self.evaluator = evaluator
        self.cross_validator = cross_validator

    def get_best_model_name(self, metric_name: str) -> str:
        idx = np.argmax(self.results_df[metric_name])
        return self.results_df.index[idx]

    def get_best_model(self, metric_name: str) -> Union[VectorClassificationModel, VectorRegressionModel, VectorModelBase]:
        result = self.result_by_model_name[self.get_best_model_name(metric_name)]
        if result.eval_data is None:
            raise ValueError("The best model is not well-defined when using cross-validation")
        return result.eval_data.model
class ModelComparisonVisitor(ABC):
    @abstractmethod
    def visit(self, model_name: str, result: ModelComparisonData.Result):
        pass

    @abstractmethod
    def collect_results(self, result_collector: EvaluationResultCollector) -> None:
        """
        Collects results (such as figures) at the end of the model comparison, based on the results gathered during visitation

        :param result_collector: the collector to which figures are to be added
        """
        pass
class ModelComparisonVisitorAggregatedFeatureImportance(ModelComparisonVisitor):
    """
    During a model comparison, computes aggregated feature importance values for the model with the given name
    """
    def __init__(self, model_name: str, feature_agg_regex: Sequence[str] = (), write_figure=True, write_data_frame_csv=False):
        r"""
        :param model_name: the name of the model for which to compute the aggregated feature importance values
        :param feature_agg_regex: a sequence of regular expressions describing which feature names to sum as one. Each regex must
            contain exactly one group. If a regex matches a feature name, the feature importance will be summed under the key
            of the matched group instead of the full feature name. For example, the regex r"(\w+)_\d+$" will cause "foo_1" and "foo_2"
            to be summed under "foo" and similarly "bar_1" and "bar_2" to be summed under "bar".
        :param write_figure: whether to add a figure with the aggregated feature importance values when collecting results
        :param write_data_frame_csv: whether to add a CSV file with the aggregated feature importance values when collecting results
        """
        self.model_name = model_name
        self.agg_feature_importance = AggregatedFeatureImportance(feature_agg_reg_ex=feature_agg_regex)
        self.write_figure = write_figure
        self.write_data_frame_csv = write_data_frame_csv

    def visit(self, model_name: str, result: ModelComparisonData.Result):
        if model_name == self.model_name:
            if result.cross_validation_data is not None:
                models = result.cross_validation_data.trained_models
                if models is not None:
                    for model in models:
                        self._collect(model)
                else:
                    raise ValueError("Models were not returned in cross-validation results")
            elif result.eval_data is not None:
                self._collect(result.eval_data.model)
    def _collect(self, model: Union[FeatureImportanceProvider, VectorModelBase]):
        if not isinstance(model, FeatureImportanceProvider):
            raise ValueError(f"Got model which does not inherit from {FeatureImportanceProvider.__qualname__}: {model}")
        self.agg_feature_importance.add(model.get_feature_importance_dict())

    @deprecated("Use get_feature_importance and create the plot using the returned object")
    def plot_feature_importance(self) -> plt.Figure:
        feature_importance_dict = self.agg_feature_importance.get_aggregated_feature_importance().get_feature_importance_dict()
        return plot_feature_importance(feature_importance_dict, subtitle=self.model_name)

    def get_feature_importance(self) -> FeatureImportance:
        return self.agg_feature_importance.get_aggregated_feature_importance()

    def collect_results(self, result_collector: EvaluationResultCollector):
        feature_importance = self.get_feature_importance()
        if self.write_figure:
            result_collector.add_figure(f"{self.model_name}_feature-importance", feature_importance.plot())
        if self.write_data_frame_csv:
            result_collector.add_data_frame_csv_file(f"{self.model_name}_feature-importance", feature_importance.get_data_frame())
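# Illustrative usage sketch: collecting aggregated feature importance for one model during a comparison.
# `ev`, `models` and `result_writer` are hypothetical objects, and "MyModel" must match the name of one of the
# compared models (which must provide feature importance values):
#
#     visitor = ModelComparisonVisitorAggregatedFeatureImportance("MyModel", write_data_frame_csv=True)
#     ev.compare_models(models, result_writer=result_writer, visitors=[visitor], write_visitor_results=True)
#     feature_importance = visitor.get_feature_importance()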
class MultiDataModelComparisonData(Generic[TEvalStats, TEvalStatsCollection], ABC):
    def __init__(self, all_results_df: pd.DataFrame,
            mean_results_df: pd.DataFrame,
            agg_results_df: pd.DataFrame,
            eval_stats_by_model_name: Dict[str, List[TEvalStats]],
            results_by_model_name: Dict[str, List[ModelComparisonData.Result]],
            dataset_names: List[str],
            model_name_to_string_repr: Dict[str, str]):
        self.all_results_df = all_results_df
        self.mean_results_df = mean_results_df
        self.agg_results_df = agg_results_df
        self.eval_stats_by_model_name = eval_stats_by_model_name
        self.results_by_model_name = results_by_model_name
        self.dataset_names = dataset_names
        self.model_name_to_string_repr = model_name_to_string_repr

    def get_model_names(self) -> List[str]:
        return list(self.eval_stats_by_model_name.keys())

    def get_model_description(self, model_name: str) -> str:
        return self.model_name_to_string_repr[model_name]

    def get_eval_stats_list(self, model_name: str) -> List[TEvalStats]:
        return self.eval_stats_by_model_name[model_name]

    @abstractmethod
    def get_eval_stats_collection(self, model_name: str) -> TEvalStatsCollection:
        pass

    def iter_model_results(self, model_name: str) -> Iterator[Tuple[str, ModelComparisonData.Result]]:
        results = self.results_by_model_name[model_name]
        yield from zip(self.dataset_names, results)
    def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_complementary=False):
        """
        For each model, creates plots of the distribution of each metric across the datasets (as histograms) and, additionally,
        x-y plots (scatter plots and heat maps) for metrics that have an associated paired metric which was also computed

        :param result_writer: the result writer
        :param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
        :param cdf_complementary: whether to plot the complementary CDF instead of the regular CDF, provided that ``cdf`` is True
        """
        for modelName in self.get_model_names():
            eval_stats_collection = self.get_eval_stats_collection(modelName)
            for metricName in eval_stats_collection.get_metric_names():
                # plot distribution
                fig = eval_stats_collection.plot_distribution(metricName, subtitle=modelName, cdf=cdf, cdf_complementary=cdf_complementary)
                result_writer.write_figure(f"{modelName}_dist-{metricName}", fig)
                # scatter plot with paired metrics
                metric: Metric = eval_stats_collection.get_metric_by_name(metricName)
                for paired_metric in metric.get_paired_metrics():
                    if eval_stats_collection.has_metric(paired_metric):
                        fig = eval_stats_collection.plot_scatter(metric.name, paired_metric.name)
                        result_writer.write_figure(f"{modelName}_scatter-{metric.name}-{paired_metric.name}", fig)
                        fig = eval_stats_collection.plot_heat_map(metric.name, paired_metric.name)
                        result_writer.write_figure(f"{modelName}_heatmap-{metric.name}-{paired_metric.name}", fig)
class ClassificationMultiDataModelComparisonData(MultiDataModelComparisonData[ClassificationEvalStats, ClassificationEvalStatsCollection]):
    def get_eval_stats_collection(self, model_name: str):
        return ClassificationEvalStatsCollection(self.get_eval_stats_list(model_name))


class RegressionMultiDataModelComparisonData(MultiDataModelComparisonData[RegressionEvalStats, RegressionEvalStatsCollection]):
    def get_eval_stats_collection(self, model_name: str):
        return RegressionEvalStatsCollection(self.get_eval_stats_list(model_name))