Coverage for src/sensai/evaluation/eval_util.py: 21% (534 statements)
1"""
2This module contains methods and classes that facilitate evaluation of different types of models. The suggested
3workflow for evaluation is to use these higher-level functionalities instead of instantiating
4the evaluation classes directly.
5"""
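# Illustrative usage sketch (an editorial addition, not part of the original module): the suggested
# high-level workflow for a single model, assuming a features DataFrame `X`, a targets DataFrame `y`
# and a model instance `my_model` (these names are hypothetical):
#
#     io_data = InputOutputData(X, y)
#     eval_data = eval_model_via_evaluator(my_model, io_data, test_fraction=0.2)
#     log.info(eval_data.get_eval_stats())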
6import functools
7import logging
8from abc import ABC, abstractmethod
9from collections import defaultdict
10from dataclasses import dataclass
11from typing import Dict, Any, Union, Generic, TypeVar, Optional, Sequence, Callable, Set, Iterable, List, Iterator, Tuple
13import matplotlib.figure
14import matplotlib.pyplot as plt
15import numpy as np
16import pandas as pd
17import seaborn as sns
19from .crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData, \
20 VectorClassificationModelCrossValidationData, \
21 VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator, VectorModelCrossValidator, VectorModelCrossValidatorParams
22from .eval_stats import RegressionEvalStatsCollection, ClassificationEvalStatsCollection, RegressionEvalStatsPlotErrorDistribution, \
23 RegressionEvalStatsPlotHeatmapGroundTruthPredictions, RegressionEvalStatsPlotScatterGroundTruthPredictions, \
24 ClassificationEvalStatsPlotConfusionMatrix, ClassificationEvalStatsPlotPrecisionRecall, RegressionEvalStatsPlot, \
25 ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall, ClassificationEvalStatsPlotProbabilityThresholdCounts, \
26 Metric
27from .eval_stats.eval_stats_base import EvalStats, EvalStatsCollection, EvalStatsPlot
28from .eval_stats.eval_stats_classification import ClassificationEvalStats
29from .eval_stats.eval_stats_regression import RegressionEvalStats
30from .evaluator import VectorModelEvaluator, VectorModelEvaluationData, VectorRegressionModelEvaluator, \
31 VectorRegressionModelEvaluationData, VectorClassificationModelEvaluator, VectorClassificationModelEvaluationData, \
32 RegressionEvaluatorParams, ClassificationEvaluatorParams
33from ..data import InputOutputData
34from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plot_feature_importance, FeatureImportance
35from ..tracking import TrackedExperiment
36from ..tracking.tracking_base import TrackingContext
37from ..util.deprecation import deprecated
38from ..util.io import ResultWriter
39from ..util.string import pretty_string_repr
40from ..vector_model import VectorClassificationModel, VectorRegressionModel, VectorModel, VectorModelBase
42log = logging.getLogger(__name__)
44TModel = TypeVar("TModel", bound=VectorModel)
45TEvalStats = TypeVar("TEvalStats", bound=EvalStats)
46TEvalStatsPlot = TypeVar("TEvalStatsPlot", bound=EvalStatsPlot)
47TEvalStatsCollection = TypeVar("TEvalStatsCollection", bound=EvalStatsCollection)
48TEvaluator = TypeVar("TEvaluator", bound=VectorModelEvaluator)
49TCrossValidator = TypeVar("TCrossValidator", bound=VectorModelCrossValidator)
50TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData)
51TCrossValData = TypeVar("TCrossValData", bound=VectorModelCrossValidationData)
54def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool]) -> bool:
55 if (model is None and is_regression is None) or (model is not None and is_regression is not None):
56 raise ValueError("Exactly one of the two parameters must be passed: model or is_regression")
58 if is_regression is None:
59 model: VectorModel
60 return model.is_regression_model()
61 return is_regression
64def create_vector_model_evaluator(data: InputOutputData, model: VectorModel = None,
65 is_regression: bool = None, params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None,
66 test_data: Optional[InputOutputData] = None) \
67 -> Union[VectorRegressionModelEvaluator, VectorClassificationModelEvaluator]:
68 is_regression = _is_regression(model, is_regression)
69 if params is None:
70 if is_regression:
71 params = RegressionEvaluatorParams(fractional_split_test_fraction=0.2)
72 else:
73 params = ClassificationEvaluatorParams(fractional_split_test_fraction=0.2)
74 log.debug(f"No evaluator parameters specified, using default: {params}")
75 if is_regression:
76 return VectorRegressionModelEvaluator(data, test_data=test_data, params=params)
77 else:
78 return VectorClassificationModelEvaluator(data, test_data=test_data, params=params)
81def create_vector_model_cross_validator(data: InputOutputData,
82 model: VectorModel = None,
83 is_regression: bool = None,
84 params: Union[VectorModelCrossValidatorParams, Dict[str, Any]] = None) \
85 -> Union[VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator]:
86 if params is None:
87 raise ValueError("params must not be None")
88 cons = VectorRegressionModelCrossValidator if _is_regression(model, is_regression) else VectorClassificationModelCrossValidator
89 return cons(data, params=params)
92def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
93 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
94 cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
95 -> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
96 if _is_regression(model, is_regression):
97 return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
98 else:
99 return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
102def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
103 plot_target_distribution=False, compute_probabilities=True, normalize_plots=True, random_seed=60) -> TEvalData:
104 """
105 Evaluates the given model via a simple evaluation mechanism that uses a single split
107 :param model: the model to evaluate
108 :param io_data: data on which to evaluate
109 :param test_fraction: the fraction of the data to test on
110 :param plot_target_distribution: whether to plot the distribution of target values over the entire dataset
111 :param compute_probabilities: whether to compute class probabilities (only relevant if the model is a classifier)
112 :param normalize_plots: whether to normalize plotted distributions such that they sum/integrate to 1
113 :param random_seed: the random seed to use for the fractional split of the data
115 :return: the evaluation data
116 """
117 if plot_target_distribution:
118 title = "Distribution of target values in entire dataset"
119 fig = plt.figure(title)
121 output_distribution_series = io_data.outputs.iloc[:, 0]
122 log.info(f"Description of target column in training set: \n{output_distribution_series.describe()}")
123 if not model.is_regression_model():
124 output_distribution_series = output_distribution_series.value_counts(normalize=normalize_plots)
125 ax = sns.barplot(x=output_distribution_series.index, y=output_distribution_series.values)
126 ax.set_ylabel("%")
127 else:
128 ax = sns.histplot(output_distribution_series, kde=True, stat="density")
129 ax.set_ylabel("Probability density")
130 ax.set_title(title)
131 ax.set_xlabel("target value")
132 fig.show()
134 if model.is_regression_model():
135 evaluator_params = RegressionEvaluatorParams(fractional_split_test_fraction=test_fraction,
136 fractional_split_random_seed=random_seed)
137 else:
138 evaluator_params = ClassificationEvaluatorParams(fractional_split_test_fraction=test_fraction,
139 compute_probabilities=compute_probabilities, fractional_split_random_seed=random_seed)
140 ev = create_evaluation_util(io_data, model=model, evaluator_params=evaluator_params)
141 return ev.perform_simple_evaluation(model, show_plots=True, log_results=True)
144class EvaluationResultCollector:
145 def __init__(self, show_plots: bool = True, result_writer: Optional[ResultWriter] = None,
146 tracking_context: TrackingContext = None):
147 self.show_plots = show_plots
148 self.result_writer = result_writer
149 self.tracking_context = tracking_context
151 def is_plot_creation_enabled(self) -> bool:
152 return self.show_plots or self.result_writer is not None or self.tracking_context is not None
154 def add_figure(self, name: str, fig: matplotlib.figure.Figure):
155 if self.result_writer is not None:
156 self.result_writer.write_figure(name, fig, close_figure=False)
157 if self.tracking_context is not None:
158 self.tracking_context.track_figure(name, fig)
159 if not self.show_plots:
160 plt.close(fig)
162 def add_data_frame_csv_file(self, name: str, df: pd.DataFrame):
163 if self.result_writer is not None:
164 self.result_writer.write_data_frame_csv_file(name, df)
166 def child(self, added_filename_prefix):
167 result_writer = self.result_writer
168 if result_writer:
169 result_writer = result_writer.child_with_added_prefix(added_filename_prefix)
170 return self.__class__(show_plots=self.show_plots, result_writer=result_writer)
173class EvalStatsPlotCollector(Generic[TEvalStats, TEvalStatsPlot]):
174 def __init__(self):
175 self.plots: Dict[str, EvalStatsPlot] = {}
176 self.disabled_plots: Set[str] = set()
178 def add_plot(self, name: str, plot: EvalStatsPlot):
179 self.plots[name] = plot
181 def get_enabled_plots(self) -> List[str]:
182 return [p for p in self.plots if p not in self.disabled_plots]
184 def disable_plots(self, *names: str):
185 self.disabled_plots.update(names)
187 def create_plots(self, eval_stats: EvalStats, subtitle: str, result_collector: EvaluationResultCollector):
188 known_plots = set(self.plots.keys())
189 unknown_disabled_plots = self.disabled_plots.difference(known_plots)
190 if len(unknown_disabled_plots) > 0:
191 log.warning(f"Plots were disabled which are not registered: {unknown_disabled_plots}; known plots: {known_plots}")
192 for name, plot in self.plots.items():
193 if name not in self.disabled_plots:
194 fig = plot.create_figure(eval_stats, subtitle)
195 if fig is not None:
196 result_collector.add_figure(name, fig)
199class RegressionEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]):
200 def __init__(self):
201 super().__init__()
202 self.add_plot("error-dist", RegressionEvalStatsPlotErrorDistribution())
203 self.add_plot("heatmap-gt-pred", RegressionEvalStatsPlotHeatmapGroundTruthPredictions())
204 self.add_plot("scatter-gt-pred", RegressionEvalStatsPlotScatterGroundTruthPredictions())
207class ClassificationEvalStatsPlotCollector(EvalStatsPlotCollector[RegressionEvalStats, RegressionEvalStatsPlot]):
208 def __init__(self):
209 super().__init__()
210 self.add_plot("confusion-matrix-rel", ClassificationEvalStatsPlotConfusionMatrix(normalise=True))
211 self.add_plot("confusion-matrix-abs", ClassificationEvalStatsPlotConfusionMatrix(normalise=False))
212 # the plots below apply to the binary case only (skipped for non-binary case)
213 self.add_plot("precision-recall", ClassificationEvalStatsPlotPrecisionRecall())
214 self.add_plot("threshold-precision-recall", ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall())
215 self.add_plot("threshold-counts", ClassificationEvalStatsPlotProbabilityThresholdCounts())
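# Illustrative usage sketch (an editorial addition, not part of the original module): the plot collectors
# above are held by the ModelEvaluation classes defined below, and individual plots can be disabled by
# name before running an evaluation (assuming an InputOutputData instance `io_data` exists):
#
#     ev = ClassificationModelEvaluation(io_data)
#     ev.eval_stats_plot_collector.disable_plots("confusion-matrix-abs", "threshold-counts")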
218class ModelEvaluation(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]):
219 """
220 Utility class for the evaluation of models based on a dataset
221 """
222 def __init__(self, io_data: InputOutputData,
223 eval_stats_plot_collector: Union[RegressionEvalStatsPlotCollector, ClassificationEvalStatsPlotCollector],
224 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams,
225 Dict[str, Any]]] = None,
226 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
227 test_io_data: Optional[InputOutputData] = None):
228 """
229 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
230 into training and test data according to the rules specified by `evaluator_params`.
231 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
232 taken to be the test data when creating evaluators for simple (single-split) evaluation.
233 :param eval_stats_plot_collector: a collector for plots generated from evaluation stats objects
234 :param evaluator_params: parameters with which to instantiate evaluators
235 :param cross_validator_params: parameters with which to instantiate cross-validators
236 :param test_io_data: optional test data (see `io_data`)
237 """
238 if cross_validator_params is None:
239 cross_validator_params = VectorModelCrossValidatorParams(folds=5)
240 self.evaluator_params = evaluator_params
241 self.cross_validator_params = cross_validator_params
242 self.io_data = io_data
243 self.test_io_data = test_io_data
244 self.eval_stats_plot_collector = eval_stats_plot_collector
246 def create_evaluator(self, model: TModel = None, is_regression: bool = None) -> TEvaluator:
247 """
248 Creates an evaluator holding the current input-output data
250 :param model: the model for which to create an evaluator (used only to determine whether a regression or a
251 classification evaluator is to be created; the resulting evaluator will work on other models as well)
252 :param is_regression: whether to create a regression model evaluator; either this or `model` must be specified
253 :return: an evaluator
254 """
255 return create_vector_model_evaluator(self.io_data, model=model, is_regression=is_regression, test_data=self.test_io_data,
256 params=self.evaluator_params)
258 def create_cross_validator(self, model: TModel = None, is_regression: bool = None) -> TCrossValidator:
259 """
260 Creates a cross-validator holding the current input-output data
262 :param model: the model for which to create a cross-validator (used only to determine whether a regression or a
263 classification cross-validator is to be created; the resulting cross-validator will work on other models as well)
264 :param is_regression: whether to create a regression model cross-validator; either this or `model` must be specified
265 :return: a cross-validator
266 """
267 return create_vector_model_cross_validator(self.io_data, model=model, is_regression=is_regression,
268 params=self.cross_validator_params)
270 def perform_simple_evaluation(self, model: TModel,
271 create_plots=True, show_plots=False,
272 log_results=True,
273 result_writer: ResultWriter = None,
274 additional_evaluation_on_training_data=False,
275 fit_model=True, write_eval_stats=False,
276 tracked_experiment: TrackedExperiment = None,
277 evaluator: Optional[TEvaluator] = None) -> TEvalData:
279 if show_plots and not create_plots:
280 raise ValueError("show_plots=True requires create_plots=True")
281 result_writer = self._result_writer_for_model(result_writer, model)
282 if evaluator is None:
283 evaluator = self.create_evaluator(model)
284 if tracked_experiment is not None:
285 evaluator.set_tracked_experiment(tracked_experiment)
286 log.info(f"Evaluating {model} via {evaluator}")
288 def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_prefix=""):
289 str_eval_results = ""
290 for predictedVarName in result_data.predicted_var_names:
291 eval_stats = result_data.get_eval_stats(predictedVarName)
292 str_eval_result = str(eval_stats)
293 if log_results:
294 log.info(f"{subtitle_prefix}Evaluation results for {predictedVarName}: {str_eval_result}")
295 str_eval_results += predictedVarName + ": " + str_eval_result + "\n"
296 if write_eval_stats and res_writer is not None:
297 res_writer.write_pickle(f"eval-stats-{predictedVarName}", eval_stats)
298 str_eval_results += f"\n\n{pretty_string_repr(model)}"
299 if res_writer is not None:
300 res_writer.write_text_file("evaluator-results", str_eval_results)
301 if create_plots:
302 with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
303 self.create_plots(result_data, show_plots=show_plots, result_writer=res_writer,
304 subtitle_prefix=subtitle_prefix, tracking_context=trackingContext)
306 eval_result_data = evaluator.eval_model(model, fit=fit_model)
307 gather_results(eval_result_data, result_writer)
308 if additional_evaluation_on_training_data:
309 eval_result_data_train = evaluator.eval_model(model, on_training_data=True, track=False)
310 additional_result_writer = result_writer.child_with_added_prefix("onTrain-") if result_writer is not None else None
311 gather_results(eval_result_data_train, additional_result_writer, subtitle_prefix="[onTrain] ")
312 return eval_result_data
314 @staticmethod
315 def _result_writer_for_model(result_writer: Optional[ResultWriter], model: TModel) -> Optional[ResultWriter]:
316 if result_writer is None:
317 return None
318 return result_writer.child_with_added_prefix(model.get_name() + "_")
320 def perform_cross_validation(self, model: TModel, show_plots=False, log_results=True, result_writer: Optional[ResultWriter] = None,
321 tracked_experiment: TrackedExperiment = None, cross_validator: Optional[TCrossValidator] = None) -> TCrossValData:
322 """
323 Evaluates the given model via cross-validation
325 :param model: the model to evaluate
326 :param show_plots: whether to show plots that visualise evaluation results (combining all folds)
327 :param log_results: whether to log evaluation results
328 :param result_writer: a writer with which to store text files and plots. The evaluated model's name is added to each filename
329 automatically
330 :param tracked_experiment: a tracked experiment with which results shall be associated
331 :param cross_validator: the cross-validator to apply; if None, a suitable cross-validator will be created
332 :return: cross-validation result data
333 """
334 result_writer = self._result_writer_for_model(result_writer, model)
336 if cross_validator is None:
337 cross_validator = self.create_cross_validator(model)
338 if tracked_experiment is not None:
339 cross_validator.set_tracked_experiment(tracked_experiment)
341 cross_validation_data = cross_validator.eval_model(model)
343 agg_stats_by_var = {varName: cross_validation_data.get_eval_stats_collection(predicted_var_name=varName).agg_metrics_dict()
344 for varName in cross_validation_data.predicted_var_names}
345 df = pd.DataFrame.from_dict(agg_stats_by_var, orient="index")
347 str_eval_results = df.to_string()
348 if log_results:
349 log.info(f"Cross-validation results:\n{str_eval_results}")
350 if result_writer is not None:
351 result_writer.write_text_file("crossval-results", str_eval_results)
353 with TrackingContext.from_optional_experiment(tracked_experiment, model=model) as trackingContext:
354 self.create_plots(cross_validation_data, show_plots=show_plots, result_writer=result_writer,
355 tracking_context=trackingContext)
357 return cross_validation_data
359 def compare_models(self, models: Sequence[TModel], result_writer: Optional[ResultWriter] = None, use_cross_validation=False,
360 fit_models=True, write_individual_results=True, sort_column: Optional[str] = None, sort_ascending: bool = True,
361 sort_column_move_to_left=True,
362 also_include_unsorted_results: bool = False, also_include_cross_val_global_stats: bool = False,
363 visitors: Optional[Iterable["ModelComparisonVisitor"]] = None,
364 write_visitor_results=False, write_csv=False,
365 tracked_experiment: Optional[TrackedExperiment] = None) -> "ModelComparisonData":
366 """
367 Compares several models via simple evaluation or cross-validation
369 :param models: the models to compare
370 :param result_writer: a writer with which to store results of the comparison
371 :param use_cross_validation: whether to use cross-validation in order to evaluate models; if False, use a simple evaluation
372 on test data (single split)
373 :param fit_models: whether to fit models before evaluating them; this can only be False if use_cross_validation=False
374 :param write_individual_results: whether to write result files for each individual model (in addition to the comparison
375 summary)
376 :param sort_column: column/metric name by which to sort; the fact that the column names change when using cross-validation
377 (aggregation function names being added) should be ignored, simply pass the (unmodified) metric name
378 :param sort_ascending: whether to sort using `sort_column` in ascending order
379 :param sort_column_move_to_left: whether to move the `sort_column` (if any) to the very left
380 :param also_include_unsorted_results: whether to also include, for the case where the results are sorted, the unsorted table of
381 results in the results text
382 :param also_include_cross_val_global_stats: whether to also include, when using cross-validation, the evaluation metrics obtained
383 when combining the predictions from all folds into a single collection. Note that for classification models,
384 this may not always be possible (if the set of classes known to the model differs across folds)
385 :param visitors: visitors which may process individual results
386 :param write_visitor_results: whether to collect results from visitors (if any) after the comparison
387 :param write_csv: whether to write metrics table to CSV files
388 :param tracked_experiment: an experiment for tracking
389 :return: the comparison results
390 """
391 # collect model evaluation results
392 stats_list = []
393 result_by_model_name = {}
394 evaluator = None
395 cross_validator = None
396 for i, model in enumerate(models, start=1):
397 model_name = model.get_name()
398 log.info(f"Evaluating model {i}/{len(models)} named '{model_name}' ...")
399 if use_cross_validation:
400 if not fit_models:
401 raise ValueError("Cross-validation necessitates that models be trained several times; got fit_models=False")
402 if cross_validator is None:
403 cross_validator = self.create_cross_validator(model)
404 cross_val_data = self.perform_cross_validation(model, result_writer=result_writer if write_individual_results else None,
405 cross_validator=cross_validator, tracked_experiment=tracked_experiment)
406 model_result = ModelComparisonData.Result(cross_validation_data=cross_val_data)
407 result_by_model_name[model_name] = model_result
408 eval_stats_collection = cross_val_data.get_eval_stats_collection()
409 stats_dict = eval_stats_collection.agg_metrics_dict()
410 else:
411 if evaluator is None:
412 evaluator = self.create_evaluator(model)
413 eval_data = self.perform_simple_evaluation(model, result_writer=result_writer if write_individual_results else None,
414 fit_model=fit_models, evaluator=evaluator, tracked_experiment=tracked_experiment)
415 model_result = ModelComparisonData.Result(eval_data=eval_data)
416 result_by_model_name[model_name] = model_result
417 eval_stats = eval_data.get_eval_stats()
418 stats_dict = eval_stats.metrics_dict()
419 stats_dict["model_name"] = model_name
420 stats_list.append(stats_dict)
421 if visitors is not None:
422 for visitor in visitors:
423 visitor.visit(model_name, model_result)
424 results_df = pd.DataFrame(stats_list).set_index("model_name")
426 # compute results data frame with combined set of data points (for cross-validation only)
427 cross_val_combined_results_df = None
428 if use_cross_validation and also_include_cross_val_global_stats:
429 try:
430 rows = []
431 for model_name, result in result_by_model_name.items():
432 stats_dict = result.cross_validation_data.get_eval_stats_collection().get_global_stats().metrics_dict()
433 stats_dict["model_name"] = model_name
434 rows.append(stats_dict)
435 cross_val_combined_results_df = pd.DataFrame(rows).set_index("model_name")
436 except Exception as e:
437 log.error(f"Creation of global stats data frame from cross-validation folds failed: {e}")
439 def sorted_df(df, sort_col):
440 if sort_col is not None:
441 if sort_col not in df.columns:
442 alt_sort_col = f"mean[{sort_col}]"
443 if alt_sort_col in df.columns:
444 sort_col = alt_sort_col
445 else:
446 log.warning(f"Requested sort column '{sort_col}' (or '{alt_sort_col}') not in list of columns {list(df.columns)}")
447 sort_col = None
448 if sort_col is not None:
449 df = df.sort_values(sort_col, ascending=sort_ascending, inplace=False)
450 if sort_column_move_to_left:
451 df = df[[sort_col] + [c for c in df.columns if c != sort_col]]
452 return df
454 # write comparison results
455 title = "Model comparison results"
456 if use_cross_validation:
457 title += ", aggregated across folds"
458 sorted_results_df = sorted_df(results_df, sort_column)
459 str_results = f"{title}:\n{sorted_results_df.to_string()}"
460 if also_include_unsorted_results and sort_column is not None:
461 str_results += f"\n\n{title} (unsorted):\n{results_df.to_string()}"
462 sorted_cross_val_combined_results_df = None
463 if cross_val_combined_results_df is not None:
464 sorted_cross_val_combined_results_df = sorted_df(cross_val_combined_results_df, sort_column)
465 str_results += f"\n\nModel comparison results based on combined set of data points from all folds:\n" \
466 f"{sorted_cross_val_combined_results_df.to_string()}"
467 log.info(str_results)
468 if result_writer is not None:
469 suffix = "crossval" if use_cross_validation else "simple-eval"
470 str_results += "\n\n" + "\n\n".join([f"{model.get_name()} = {model.pprints()}" for model in models])
471 result_writer.write_text_file(f"model-comparison-results-{suffix}", str_results)
472 if write_csv:
473 result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}", sorted_results_df)
474 if sorted_cross_val_combined_results_df is not None:
475 result_writer.write_data_frame_csv_file(f"model-comparison-metrics-{suffix}-combined",
476 sorted_cross_val_combined_results_df)
478 # write visitor results
479 if visitors is not None and write_visitor_results:
480 result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
481 for visitor in visitors:
482 visitor.collect_results(result_collector)
484 return ModelComparisonData(results_df, result_by_model_name, evaluator=evaluator, cross_validator=cross_validator)
486 def compare_models_cross_validation(self, models: Sequence[TModel],
487 result_writer: Optional[ResultWriter] = None) -> "ModelComparisonData":
488 """
489 Compares several models via cross-validation
491 :param models: the models to compare
492 :param result_writer: a writer with which to store results of the comparison
493 :return: the comparison results
494 """
495 return self.compare_models(models, result_writer=result_writer, use_cross_validation=True)
497 def create_plots(self, data: Union[TEvalData, TCrossValData], show_plots=True, result_writer: Optional[ResultWriter] = None,
498 subtitle_prefix: str = "", tracking_context: Optional[TrackingContext] = None):
499 """
500 Creates default plots that visualise the results in the given evaluation data
502 :param data: the evaluation data for which to create the default plots
503 :param show_plots: whether to show plots
504 :param result_writer: if not None, plots will be written using this writer
505 :param subtitle_prefix: a prefix to add to the subtitle (which itself is the model name)
506 :param tracking_context: the experiment tracking context
507 """
508 result_collector = EvaluationResultCollector(show_plots=show_plots, result_writer=result_writer,
509 tracking_context=tracking_context)
510 if result_collector.is_plot_creation_enabled():
511 self._create_plots(data, result_collector, subtitle=subtitle_prefix + data.model_name)
513 def _create_plots(self, data: Union[TEvalData, TCrossValData], result_collector: EvaluationResultCollector, subtitle=None):
515 def create_plots(pred_var_name, res_collector, subt):
516 if isinstance(data, VectorModelCrossValidationData):
517 eval_stats = data.get_eval_stats_collection(predicted_var_name=pred_var_name).get_global_stats()
518 elif isinstance(data, VectorModelEvaluationData):
519 eval_stats = data.get_eval_stats(predicted_var_name=pred_var_name)
520 else:
521 raise ValueError(f"Unexpected argument: data={data}")
522 return self._create_eval_stats_plots(eval_stats, res_collector, subtitle=subt)
524 predicted_var_names = data.predicted_var_names
525 if len(predicted_var_names) == 1:
526 create_plots(predicted_var_names[0], result_collector, subtitle)
527 else:
528 for predictedVarName in predicted_var_names:
529 create_plots(predictedVarName, result_collector.child(predictedVarName + "-"), f"{predictedVarName}, {subtitle}")
531 def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: EvaluationResultCollector, subtitle=None):
532 """
533 :param eval_stats: the evaluation results for which to create plots
534 :param result_collector: the collector to which all plots are to be passed
535 :param subtitle: the subtitle to use for generated plots (if any)
536 """
537 self.eval_stats_plot_collector.create_plots(eval_stats, subtitle, result_collector)
540class RegressionModelEvaluation(ModelEvaluation[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData,
541 VectorRegressionModelCrossValidator, VectorRegressionModelCrossValidationData, RegressionEvalStats]):
542 def __init__(self, io_data: InputOutputData,
543 evaluator_params: Optional[Union[RegressionEvaluatorParams, Dict[str, Any]]] = None,
544 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
545 test_io_data: Optional[InputOutputData] = None):
546 """
547 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
548 into training and test data according to the rules specified by `evaluator_params`.
549 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
550 taken to be the test data when creating evaluators for simple (single-split) evaluation.
551 :param evaluator_params: parameters with which to instantiate evaluators
552 :param cross_validator_params: parameters with which to instantiate cross-validators
553 :param test_io_data: optional test data (see `io_data`)
554 """
555 super().__init__(io_data, eval_stats_plot_collector=RegressionEvalStatsPlotCollector(), evaluator_params=evaluator_params,
556 cross_validator_params=cross_validator_params, test_io_data=test_io_data)
559class ClassificationModelEvaluation(ModelEvaluation[VectorClassificationModel, VectorClassificationModelEvaluator,
560 VectorClassificationModelEvaluationData, VectorClassificationModelCrossValidator, VectorClassificationModelCrossValidationData,
561 ClassificationEvalStats]):
562 def __init__(self, io_data: InputOutputData,
563 evaluator_params: Optional[Union[ClassificationEvaluatorParams, Dict[str, Any]]] = None,
564 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
565 test_io_data: Optional[InputOutputData] = None):
566 """
567 :param io_data: the data set to use for evaluation. For evaluation purposes, this dataset usually will be split
568 into training and test data according to the rules specified by `evaluator_params`.
569 However, if `test_io_data` is specified, then this is taken to be the training data and `test_io_data` is
570 taken to be the test data when creating evaluators for simple (single-split) evaluation.
571 :param evaluator_params: parameters with which to instantiate evaluators
572 :param cross_validator_params: parameters with which to instantiate cross-validators
573 :param test_io_data: optional test data (see `io_data`)
574 """
575 super().__init__(io_data, eval_stats_plot_collector=ClassificationEvalStatsPlotCollector(), evaluator_params=evaluator_params,
576 cross_validator_params=cross_validator_params, test_io_data=test_io_data)
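# Illustrative usage sketch (an editorial addition, not part of the original module): comparing several
# models on one dataset via the high-level classes above. `io_data` and the models `model_a`, `model_b`
# are assumed to exist, and "RMSE" is assumed to be among the computed regression metrics:
#
#     ev = RegressionModelEvaluation(io_data,
#                                    evaluator_params=RegressionEvaluatorParams(fractional_split_test_fraction=0.2))
#     comparison = ev.compare_models([model_a, model_b], sort_column="RMSE", sort_ascending=True)
#     log.info(f"\n{comparison.results_df}")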
579class MultiDataModelEvaluation:
580 def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
581 meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
582 evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
583 cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
584 test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None):
585 """
586 :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
587 For evaluation or cross-validation, these datasets will usually be split according to the rules
588 specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
589 explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data
590 sets, the io_data will not be split for evaluation, but the test_io_data will be used instead.
591 :param key_name: a name for the key value used in io_data_dict, which will be used as a column name in result data frames
592 :param meta_data_dict: a dictionary which maps from a name (same keys as in io_data_dict) to a dictionary, which maps
593 from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
594 :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation=False)
595 :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation=True)
596 :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
597 Entries with non-None values will be used for evaluation of the models that were trained on the respective io_data_dict.
598 If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g.
599 if you want to use test data sets for some entries and to split the io_data for others).
600 If not None, cross-validation cannot be used when calling ``compare_models``.
601 """
602 if test_io_data_dict is not None:
603 missing_keys = set(io_data_dict).difference(test_io_data_dict)
604 if len(missing_keys) > 0:
605 raise ValueError(
606 "If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys."
607 f" However, the following keys are missing: {missing_keys}")
608 self.io_data_dict = io_data_dict
609 self.test_io_data_dict = test_io_data_dict
611 self.key_name = key_name
612 self.evaluator_params = evaluator_params
613 self.cross_validator_params = cross_validator_params
614 if meta_data_dict is not None:
615 self.meta_df = pd.DataFrame(meta_data_dict.values(), index=meta_data_dict.keys())
616 else:
617 self.meta_df = None
619 def compare_models(self,
620 model_factories: Sequence[Callable[[], Union[VectorRegressionModel, VectorClassificationModel]]],
621 use_cross_validation=False,
622 result_writer: Optional[ResultWriter] = None,
623 write_per_dataset_results=False,
624 write_csvs=False,
625 column_name_for_model_ranking: str = None,
626 rank_max=True,
627 add_combined_eval_stats=False,
628 create_metric_distribution_plots=True,
629 create_combined_eval_stats_plots=False,
630 distribution_plots_cdf=True,
631 distribution_plots_cdf_complementary=False,
632 visitors: Optional[Iterable["ModelComparisonVisitor"]] = None) \
633 -> Union["RegressionMultiDataModelComparisonData", "ClassificationMultiDataModelComparisonData"]:
634 """
635 :param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
636 in a model with a fixed model name (otherwise results cannot be correctly aggregated)
637 :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
638 This can only be used if the instance's ``test_io_data_dict`` is None.
639 :param result_writer: a writer with which to store results; if None, results are not stored
640 :param write_per_dataset_results: whether to use result_writer (if not None) in order to generate detailed results for each
641 dataset in a subdirectory named according to the name of the dataset
642 :param write_csvs: whether to write metrics table to CSV files
643 :param column_name_for_model_ranking: column name to use for ranking models
644 :param rank_max: if True, use the maximum for ranking, else the minimum
645 :param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
646 all EvalStats objects.
647 Note that for classification, this is only possible if all individual experiments use the same set of class labels.
648 :param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
649 datasets (applies only if result_writer is not None)
650 :param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
651 experiments into a single object that holds all results and to use it to create plots reflecting the overall result (applies only
652 if result_writer is not None).
653 Note that for classification, this is only possible if all individual experiments use the same set of class labels.
654 :param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
655 create_metric_distribution_plots is True and result_writer is not None.
656 :param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that
657 distribution_plots_cdf is True.
658 :param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
659 comparison.
660 :return: an object containing the full comparison results
661 """
662 if self.test_io_data_dict and use_cross_validation:
663 raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")
665 all_results_df = pd.DataFrame()
666 eval_stats_by_model_name = defaultdict(list)
667 results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
668 is_regression = None
669 plot_collector: Optional[EvalStatsPlotCollector] = None
670 model_names = None
671 model_name_to_string_repr = None
673 for i, (key, inputOutputData) in enumerate(self.io_data_dict.items(), start=1):
674 log.info(f"Evaluating models for data set #{i}/{len(self.io_data_dict)}: {self.key_name}={key}")
675 models = [f() for f in model_factories]
677 current_model_names = [model.get_name() for model in models]
678 if model_names is None:
679 model_names = current_model_names
680 elif model_names != current_model_names:
681 log.warning(f"Model factories do not produce fixed names; use model.with_name to name your models. "
682 f"Got {current_model_names}, previously got {model_names}")
684 if is_regression is None:
685 models_are_regression = [model.is_regression_model() for model in models]
686 if all(models_are_regression):
687 is_regression = True
688 elif not any(models_are_regression):
689 is_regression = False
690 else:
691 raise ValueError("The models have to be either all regression models or all classification models, not a mixture")
693 test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
694 ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
695 cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)
697 if plot_collector is None:
698 plot_collector = ev.eval_stats_plot_collector
700 # compute data frame with results for current data set
701 if write_per_dataset_results and result_writer is not None:
702 child_result_writer = result_writer.child_for_subdirectory(key)
703 else:
704 child_result_writer = None
705 comparison_data = ev.compare_models(models, use_cross_validation=use_cross_validation, result_writer=child_result_writer,
706 visitors=visitors, write_visitor_results=False)
707 df = comparison_data.results_df
709 # augment data frame
710 df[self.key_name] = key
711 df["model_name"] = df.index
712 df = df.reset_index(drop=True)
714 # collect eval stats objects by model name
715 for modelName, result in comparison_data.result_by_model_name.items():
716 if use_cross_validation:
717 eval_stats = result.cross_validation_data.get_eval_stats_collection().get_global_stats()
718 else:
719 eval_stats = result.eval_data.get_eval_stats()
720 eval_stats_by_model_name[modelName].append(eval_stats)
721 results_by_model_name[modelName].append(result)
723 all_results_df = pd.concat((all_results_df, df))
725 if model_name_to_string_repr is None:
726 model_name_to_string_repr = {model.get_name(): model.pprints() for model in models}
728 if self.meta_df is not None:
729 all_results_df = all_results_df.join(self.meta_df, on=self.key_name, how="left")
731 str_all_results = f"All results:\n{all_results_df.to_string()}"
732 log.info(str_all_results)
734 # create mean result by model, removing any metrics/columns that produced NaN values
735 # (the mean would otherwise be computed without them, since a skipna parameter is not supported here)
736 all_results_grouped = all_results_df.drop(columns=self.key_name).dropna(axis=1).groupby("model_name")
737 mean_results_df: pd.DataFrame = all_results_grouped.mean()
738 for colName in [column_name_for_model_ranking, f"mean[{column_name_for_model_ranking}]"]:
739 if colName in mean_results_df:
740 mean_results_df.sort_values(colName, inplace=True, ascending=not rank_max)
741 break
742 str_mean_results = f"Mean results (averaged across {len(self.io_data_dict)} data sets):\n{mean_results_df.to_string()}"
743 log.info(str_mean_results)
745 def iter_combined_eval_stats_from_all_data_sets():
746 for model_name, evalStatsList in eval_stats_by_model_name.items():
747 if is_regression:
748 ev_stats = RegressionEvalStatsCollection(evalStatsList).get_global_stats()
749 else:
750 ev_stats = ClassificationEvalStatsCollection(evalStatsList).get_global_stats()
751 yield model_name, ev_stats
753 # create further aggregations
754 agg_dfs = []
755 for op_name, agg_fn in [("mean", lambda x: x.mean()), ("std", lambda x: x.std()), ("min", lambda x: x.min()),
756 ("max", lambda x: x.max())]:
757 agg_df = agg_fn(all_results_grouped)
758 agg_df.columns = [f"{op_name}[{c}]" for c in agg_df.columns]
759 agg_dfs.append(agg_df)
760 further_aggs_df = pd.concat(agg_dfs, axis=1)
761 further_aggs_df = further_aggs_df.loc[mean_results_df.index] # apply same sort order (index is model_name)
762 column_order = functools.reduce(lambda a, b: a + b, [list(t) for t in zip(*[df.columns for df in agg_dfs])])
763 further_aggs_df = further_aggs_df[column_order]
764 str_further_aggs = f"Further aggregations:\n{further_aggs_df.to_string()}"
765 log.info(str_further_aggs)
767 # combined eval stats from all datasets (per model)
768 str_combined_eval_stats = ""
769 if add_combined_eval_stats:
770 rows = []
771 for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
772 rows.append({"model_name": modelName, **eval_stats.metrics_dict()})
773 combined_stats_df = pd.DataFrame(rows)
774 combined_stats_df.set_index("model_name", drop=True, inplace=True)
775 combined_stats_df = combined_stats_df.loc[mean_results_df.index] # apply same sort order (index is model_name)
776 str_combined_eval_stats = f"Results on combined test data from all data sets:\n{combined_stats_df.to_string()}\n\n"
777 log.info(str_combined_eval_stats)
779 if result_writer is not None:
780 comparison_content = str_mean_results + "\n\n" + str_further_aggs + "\n\n" + str_combined_eval_stats + str_all_results
781 comparison_content += "\n\nModels [example instance]:\n\n"
782 comparison_content += "\n\n".join(f"{name} = {s}" for name, s in model_name_to_string_repr.items())
783 result_writer.write_text_file("model-comparison-results", comparison_content)
784 if write_csvs:
785 result_writer.write_data_frame_csv_file("all-results", all_results_df)
786 result_writer.write_data_frame_csv_file("mean-results", mean_results_df)
788 # create plots from combined data for each model
789 if create_combined_eval_stats_plots:
790 for modelName, eval_stats in iter_combined_eval_stats_from_all_data_sets():
791 child_result_writer = result_writer.child_with_added_prefix(modelName + "_") if result_writer is not None else None
792 result_collector = EvaluationResultCollector(show_plots=False, result_writer=child_result_writer)
793 plot_collector.create_plots(eval_stats, subtitle=modelName, result_collector=result_collector)
795 # collect results from visitors (if any)
796 result_collector = EvaluationResultCollector(show_plots=False, result_writer=result_writer)
797 if visitors is not None:
798 for visitor in visitors:
799 visitor.collect_results(result_collector)
801 # create result
802 dataset_names = list(self.io_data_dict.keys())
803 if is_regression:
804 mdmc_data = RegressionMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df, eval_stats_by_model_name,
805 results_by_model_name, dataset_names, model_name_to_string_repr)
806 else:
807 mdmc_data = ClassificationMultiDataModelComparisonData(all_results_df, mean_results_df, further_aggs_df,
808 eval_stats_by_model_name, results_by_model_name, dataset_names, model_name_to_string_repr)
810 # plot distributions
811 if create_metric_distribution_plots and result_writer is not None:
812 mdmc_data.create_distribution_plots(result_writer, cdf=distribution_plots_cdf,
813 cdf_complementary=distribution_plots_cdf_complementary)
815 return mdmc_data
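# Illustrative usage sketch (an editorial addition, not part of the original module): benchmarking model
# factories across several datasets. The InputOutputData instances and the factory functions are assumed
# to exist; each factory should return a model with a fixed name:
#
#     multi_eval = MultiDataModelEvaluation({"dataset_a": io_data_a, "dataset_b": io_data_b}, key_name="dataset")
#     mdmc_data = multi_eval.compare_models([create_model_a, create_model_b], use_cross_validation=True)
#     log.info(f"\n{mdmc_data.mean_results_df}")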
818class ModelComparisonData:
819 @dataclass
820 class Result:
821 eval_data: Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData] = None
822 cross_validation_data: Union[VectorClassificationModelCrossValidationData, VectorRegressionModelCrossValidationData] = None
824 def iter_evaluation_data(self) -> Iterator[Union[VectorClassificationModelEvaluationData, VectorRegressionModelEvaluationData]]:
825 if self.eval_data is not None:
826 yield self.eval_data
827 if self.cross_validation_data is not None:
828 yield from self.cross_validation_data.eval_data_list
830 def __init__(self, results_df: pd.DataFrame, results_by_model_name: Dict[str, Result], evaluator: Optional[VectorModelEvaluator] = None,
831 cross_validator: Optional[VectorModelCrossValidator] = None):
832 self.results_df = results_df
833 self.result_by_model_name = results_by_model_name
834 self.evaluator = evaluator
835 self.cross_validator = cross_validator
837 def get_best_model_name(self, metric_name: str) -> str:
838 idx = np.argmax(self.results_df[metric_name])
839 return self.results_df.index[idx]
841 def get_best_model(self, metric_name: str) -> Union[VectorClassificationModel, VectorRegressionModel, VectorModelBase]:
842 result = self.result_by_model_name[self.get_best_model_name(metric_name)]
843 if result.eval_data is None:
844 raise ValueError("The best model is not well-defined when using cross-validation")
845 return result.eval_data.model
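# Illustrative usage sketch (an editorial addition, not part of the original module): retrieving the best
# model from a comparison based on a metric for which larger values are better ("R2" is assumed to be
# among the computed metrics, and `comparison` to be the result of a simple-evaluation comparison):
#
#     best_name = comparison.get_best_model_name("R2")
#     best_model = comparison.get_best_model("R2")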
848class ModelComparisonVisitor(ABC):
849 @abstractmethod
850 def visit(self, model_name: str, result: ModelComparisonData.Result):
851 pass
853 @abstractmethod
854 def collect_results(self, result_collector: EvaluationResultCollector) -> None:
855 """
856 Collects results (such as figures) at the end of the model comparison, based on the data gathered during the individual visits
858 :param result_collector: the collector to which figures are to be added
859 """
860 pass
863class ModelComparisonVisitorAggregatedFeatureImportance(ModelComparisonVisitor):
864 """
865 During a model comparison, computes aggregated feature importance values for the model with the given name
866 """
867 def __init__(self, model_name: str, feature_agg_regex: Sequence[str] = (), write_figure=True, write_data_frame_csv=False):
868 r"""
869 :param model_name: the name of the model for which to compute the aggregated feature importance values
870 :param feature_agg_regex: a sequence of regular expressions describing which feature names to sum as one. Each regex must
871 contain exactly one group. If a regex matches a feature name, the feature importance will be summed under the key
872 of the matched group instead of the full feature name. For example, the regex r"(\w+)_\d+$" will cause "foo_1" and "foo_2"
873 to be summed under "foo" and similarly "bar_1" and "bar_2" to be summed under "bar".
874 """
875 self.model_name = model_name
876 self.agg_feature_importance = AggregatedFeatureImportance(feature_agg_reg_ex=feature_agg_regex)
877 self.write_figure = write_figure
878 self.write_data_frame_csv = write_data_frame_csv
880 def visit(self, model_name: str, result: ModelComparisonData.Result):
881 if model_name == self.model_name:
882 if result.cross_validation_data is not None:
883 models = result.cross_validation_data.trained_models
884 if models is not None:
885 for model in models:
886 self._collect(model)
887 else:
888 raise ValueError("Models were not returned in cross-validation results")
889 elif result.eval_data is not None:
890 self._collect(result.eval_data.model)
892 def _collect(self, model: Union[FeatureImportanceProvider, VectorModelBase]):
893 if not isinstance(model, FeatureImportanceProvider):
894 raise ValueError(f"Got model which does not inherit from {FeatureImportanceProvider.__qualname__}: {model}")
895 self.agg_feature_importance.add(model.get_feature_importance_dict())
897 @deprecated("Use get_feature_importance and create the plot using the returned object")
898 def plot_feature_importance(self) -> plt.Figure:
899 feature_importance_dict = self.agg_feature_importance.get_aggregated_feature_importance().get_feature_importance_dict()
900 return plot_feature_importance(feature_importance_dict, subtitle=self.model_name)
902 def get_feature_importance(self) -> FeatureImportance:
903 return self.agg_feature_importance.get_aggregated_feature_importance()
905 def collect_results(self, result_collector: EvaluationResultCollector):
906 feature_importance = self.get_feature_importance()
907 if self.write_figure:
908 result_collector.add_figure(f"{self.model_name}_feature-importance", feature_importance.plot())
909 if self.write_data_frame_csv:
910 result_collector.add_data_frame_csv_file(f"{self.model_name}_feature-importance", feature_importance.get_data_frame())
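# Illustrative usage sketch (an editorial addition, not part of the original module): aggregating feature
# importance values for a model named "my_model" during a comparison. This assumes the compared models
# implement FeatureImportanceProvider and that `ev`, `models` and `result_writer` already exist:
#
#     visitor = ModelComparisonVisitorAggregatedFeatureImportance("my_model")
#     ev.compare_models(models, visitors=[visitor], write_visitor_results=True, result_writer=result_writer)
#     fig = visitor.get_feature_importance().plot()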
913class MultiDataModelComparisonData(Generic[TEvalStats, TEvalStatsCollection], ABC):
914 def __init__(self, all_results_df: pd.DataFrame,
915 mean_results_df: pd.DataFrame,
916 agg_results_df: pd.DataFrame,
917 eval_stats_by_model_name: Dict[str, List[TEvalStats]],
918 results_by_model_name: Dict[str, List[ModelComparisonData.Result]],
919 dataset_names: List[str],
920 model_name_to_string_repr: Dict[str, str]):
921 self.all_results_df = all_results_df
922 self.mean_results_df = mean_results_df
923 self.agg_results_df = agg_results_df
924 self.eval_stats_by_model_name = eval_stats_by_model_name
925 self.results_by_model_name = results_by_model_name
926 self.dataset_names = dataset_names
927 self.model_name_to_string_repr = model_name_to_string_repr
929 def get_model_names(self) -> List[str]:
930 return list(self.eval_stats_by_model_name.keys())
932 def get_model_description(self, model_name: str) -> str:
933 return self.model_name_to_string_repr[model_name]
935 def get_eval_stats_list(self, model_name: str) -> List[TEvalStats]:
936 return self.eval_stats_by_model_name[model_name]
938 @abstractmethod
939 def get_eval_stats_collection(self, model_name: str) -> TEvalStatsCollection:
940 pass
942 def iter_model_results(self, model_name: str) -> Iterator[Tuple[str, ModelComparisonData.Result]]:
943 results = self.results_by_model_name[model_name]
944 yield from zip(self.dataset_names, results)
946 def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_complementary=False):
947 """
948 Creates, for each model, plots of the distribution of each metric across the datasets (as histograms) and, additionally,
949 x-y plots (scatter plots & heat maps) for metrics that have associated paired metrics which were also computed
951 :param result_writer: the result writer
952 :param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
953 :param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True
954 """
955 for modelName in self.get_model_names():
956 eval_stats_collection = self.get_eval_stats_collection(modelName)
957 for metricName in eval_stats_collection.get_metric_names():
958 # plot distribution
959 fig = eval_stats_collection.plot_distribution(metricName, subtitle=modelName, cdf=cdf, cdf_complementary=cdf_complementary)
960 result_writer.write_figure(f"{modelName}_dist-{metricName}", fig)
961 # scatter plot with paired metrics
962 metric: Metric = eval_stats_collection.get_metric_by_name(metricName)
963 for paired_metric in metric.get_paired_metrics():
964 if eval_stats_collection.has_metric(paired_metric):
965 fig = eval_stats_collection.plot_scatter(metric.name, paired_metric.name)
966 result_writer.write_figure(f"{modelName}_scatter-{metric.name}-{paired_metric.name}", fig)
967 fig = eval_stats_collection.plot_heat_map(metric.name, paired_metric.name)
968 result_writer.write_figure(f"{modelName}_heatmap-{metric.name}-{paired_metric.name}", fig)
971class ClassificationMultiDataModelComparisonData(MultiDataModelComparisonData[ClassificationEvalStats, ClassificationEvalStatsCollection]):
972 def get_eval_stats_collection(self, model_name: str):
973 return ClassificationEvalStatsCollection(self.get_eval_stats_list(model_name))
976class RegressionMultiDataModelComparisonData(MultiDataModelComparisonData[RegressionEvalStats, RegressionEvalStatsCollection]):
977 def get_eval_stats_collection(self, model_name: str):
978 return RegressionEvalStatsCollection(self.get_eval_stats_list(model_name))