Coverage for src/sensai/evaluation/metric_computation.py: 0% (37 statements)


from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union, List, Callable

from sensai import VectorRegressionModel, VectorClassificationModel, VectorModelBase
from sensai.evaluation import MultiDataModelEvaluation
from sensai.evaluation.eval_stats import RegressionMetric, ClassificationMetric

TMetric = Union[RegressionMetric, ClassificationMetric]
TModel = Union[VectorClassificationModel, VectorRegressionModel]


@dataclass
class MetricComputationResult:
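    """Holds a computed metric value along with the models that were trained in the course of computing it."""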

    metric_value: float
    models: List[VectorModelBase]


class MetricComputation(ABC):
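    """Abstraction for the computation of a metric value for models created via a model factory."""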

    def __init__(self, metric: TMetric):
        self.metric = metric

    @abstractmethod
    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
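        """
        Computes the metric value, creating the required model(s) via the given factory.

        :param model_factory: a no-argument function with which to create the model(s) to be evaluated
        :return: the result containing the metric value and the trained models
        """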

        pass


class MetricComputationMultiData(MetricComputation):
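    """Computes a metric for models evaluated on multiple datasets, using either a single split or cross-validation on each dataset."""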

    def __init__(self, ev_util: MultiDataModelEvaluation, use_cross_validation: bool, metric: TMetric,
            use_combined_eval_stats: bool):
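        """
        :param ev_util: the multi-data evaluation utility defining the datasets on which to evaluate
        :param use_cross_validation: whether to use cross-validation (rather than a single train/test split) on each dataset
        :param metric: the metric to compute
        :param use_combined_eval_stats: whether to compute the metric value from the evaluation statistics of all datasets combined
        """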

        super().__init__(metric)
        self.use_combined_eval_stats = use_combined_eval_stats
        self.ev_util = ev_util
        self.use_cross_validation = use_cross_validation

    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
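        """
        Computes the metric value across all datasets, creating the evaluated model(s) via the given factory.

        :param model_factory: a no-argument function with which to create the model(s) to be evaluated
        :return: the result containing the metric value and all models trained in the process
        """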

        # evaluate models created by the factory on each of the configured datasets
        result = self.ev_util.compare_models([model_factory], use_cross_validation=self.use_cross_validation)
        if self.use_combined_eval_stats:
            assert len(result.get_model_names()) == 1, "Model factory must produce named models"
            model_name = result.get_model_names()[0]
            metric_value = result.get_eval_stats_collection(model_name).get_combined_eval_stats().compute_metric_value(self.metric)
            # collect the trained models: several per dataset under cross-validation, one per dataset otherwise
            models = []
            for dataset_name, comparison_result in result.iter_model_results(model_name):
                if self.use_cross_validation:
                    models.extend(comparison_result.cross_validation_data.trained_models)
                else:
                    models.append(comparison_result.eval_data.model)
            return MetricComputationResult(metric_value, models)
        else:
            raise NotImplementedError()
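

# The function below is an illustrative sketch, not part of the original module: it shows
# how MetricComputationMultiData might be applied, given a pre-configured
# MultiDataModelEvaluation and a regression model factory. RegressionMetricRMSE is
# assumed (not verified here) to be importable from sensai.evaluation.eval_stats.
def _example_compute_combined_rmse(ev_util: MultiDataModelEvaluation,
        model_factory: Callable[[], VectorRegressionModel]) -> MetricComputationResult:
    from sensai.evaluation.eval_stats import RegressionMetricRMSE  # assumed export
    computation = MetricComputationMultiData(ev_util, use_cross_validation=True,
        metric=RegressionMetricRMSE(), use_combined_eval_stats=True)
    # the combined eval stats across all datasets yield a single RMSE value
    return computation.compute_metric_value(model_factory)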