Coverage for src/sensai/evaluation/metric_computation.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass
3from typing import Union, List, Callable
5from sensai import VectorRegressionModel, VectorClassificationModel, VectorModelBase
6from sensai.evaluation import MultiDataModelEvaluation
7from sensai.evaluation.eval_stats import RegressionMetric, ClassificationMetric
# Model types that a model factory passed to a MetricComputation may produce.
TModel = Union[VectorClassificationModel, VectorRegressionModel]
# Metric types (regression or classification) that a MetricComputation can compute.
TMetric = Union[RegressionMetric, ClassificationMetric]
@dataclass
class MetricComputationResult:
    """
    The outcome of a metric computation: the computed metric value together with
    the models that were obtained while computing it.
    """
    # the value of the metric that was computed
    metric_value: float
    # the models obtained during the computation (e.g. one per dataset, or several per dataset
    # when cross-validation is used)
    models: List[VectorModelBase]
class MetricComputation(ABC):
    """
    Abstract base class for strategies that compute the value of a given metric
    for the model(s) created by a model factory.
    """

    def __init__(self, metric: TMetric):
        """
        :param metric: the metric whose value is to be computed
        """
        self.metric = metric

    @abstractmethod
    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
        """
        Computes the value of this instance's metric for the model(s) created by the given factory.

        :param model_factory: a zero-argument callable which creates the model to be evaluated
        :return: the metric value along with the models that were obtained in the process
        """
class MetricComputationMultiData(MetricComputation):
    """
    Computes a metric by evaluating the models created by a model factory on multiple datasets
    using a :class:`MultiDataModelEvaluation`, optionally via cross-validation.
    """

    def __init__(self, ev_util: MultiDataModelEvaluation, use_cross_validation: bool, metric: TMetric,
            use_combined_eval_stats: bool):
        """
        :param ev_util: the multi-data evaluation utility with which models are evaluated
        :param use_cross_validation: whether to evaluate using cross-validation (instead of a single
            evaluation per dataset)
        :param metric: the metric whose value is to be computed
        :param use_combined_eval_stats: whether to compute the metric on the evaluation statistics
            combined across all datasets; currently, only True is supported
        """
        super().__init__(metric)
        self.use_combined_eval_stats = use_combined_eval_stats
        self.ev_util = ev_util
        self.use_cross_validation = use_cross_validation

    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
        """
        Evaluates the model(s) created by the given factory on all datasets of the evaluation utility
        and computes the metric on the combined evaluation statistics.

        :param model_factory: a zero-argument callable which creates the model to be evaluated
        :return: the metric value along with the models obtained during evaluation (the trained models
            of all cross-validation folds if cross-validation is used, otherwise the evaluated model of
            each dataset)
        :raises NotImplementedError: if this instance was configured with use_combined_eval_stats=False
        :raises ValueError: if the evaluation does not yield exactly one model name
        """
        # Fail fast: check the unsupported mode before running a (potentially expensive)
        # multi-dataset evaluation whose result would otherwise be discarded.
        if not self.use_combined_eval_stats:
            raise NotImplementedError("Only use_combined_eval_stats=True is currently supported")

        result = self.ev_util.compare_models([model_factory], use_cross_validation=self.use_cross_validation)

        model_names = result.get_model_names()
        # Explicit check rather than `assert`, so that the validation is not stripped under `python -O`.
        if len(model_names) != 1:
            raise ValueError(f"Model factory must produce named models; got model names {model_names}")
        model_name = model_names[0]

        metric_value = result.get_eval_stats_collection(model_name).get_combined_eval_stats().compute_metric_value(self.metric)

        models = []
        for _, comparison_result in result.iter_model_results(model_name):
            if self.use_cross_validation:
                models.extend(comparison_result.cross_validation_data.trained_models)
            else:
                models.append(comparison_result.eval_data.model)
        return MetricComputationResult(metric_value, models)