Source code for sensai.evaluation.metric_computation
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union, List, Callable

from sensai import VectorRegressionModel, VectorClassificationModel, VectorModelBase
from sensai.evaluation import MultiDataModelEvaluation
from sensai.evaluation.eval_stats import RegressionMetric, ClassificationMetric

TMetric = Union[RegressionMetric, ClassificationMetric]
TModel = Union[VectorClassificationModel, VectorRegressionModel]

@dataclass
class MetricComputationResult:
    """Holds a computed metric value along with the models that were trained to produce it."""
    metric_value: float
    models: List[VectorModelBase]

class MetricComputation(ABC):
    """Abstraction for the computation of a metric value from models created via a factory."""

    def __init__(self, metric: TMetric):
        self.metric = metric

    @abstractmethod
    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
        """
        Creates one or more models via the given factory, evaluates them and computes
        the value of this computation's metric.

        :param model_factory: a factory creating the model instance(s) to evaluate
        :return: the metric value and the models that were trained
        """
        pass
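
# Illustrative sketch only (not part of the original module): a minimal
# subclass showing the MetricComputation contract, e.g. for testing code
# that consumes MetricComputationResult. The class name is hypothetical.
class _ConstantMetricComputation(MetricComputation):
    """Ignores the model factory and returns a fixed metric value with no models."""

    def __init__(self, metric: TMetric, fixed_value: float):
        super().__init__(metric)
        self.fixed_value = fixed_value

    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
        return MetricComputationResult(metric_value=self.fixed_value, models=[])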

class MetricComputationMultiData(MetricComputation):
    """Computes a metric value for a single model across multiple datasets."""

    def __init__(self, ev_util: MultiDataModelEvaluation, use_cross_validation: bool, metric: TMetric,
            use_combined_eval_stats: bool):
        super().__init__(metric)
        self.use_combined_eval_stats = use_combined_eval_stats
        self.ev_util = ev_util
        self.use_cross_validation = use_cross_validation

    def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricComputationResult:
        result = self.ev_util.compare_models([model_factory], use_cross_validation=self.use_cross_validation)
        if self.use_combined_eval_stats:
            model_names = result.get_model_names()
            assert len(model_names) == 1, f"Expected exactly one model, got {len(model_names)}"
            model_name = model_names[0]
            # compute the metric on the evaluation statistics combined across all datasets
            metric_value = result.get_eval_stats_collection(model_name).get_combined_eval_stats() \
                .compute_metric_value(self.metric)
            # collect all models that were trained in the process
            models = []
            for dataset_name, comparison_result in result.iter_model_results(model_name):
                if self.use_cross_validation:
                    models.extend(comparison_result.cross_validation_data.trained_models)
                else:
                    models.append(comparison_result.eval_data.model)
            return MetricComputationResult(metric_value, models)
        else:
            raise NotImplementedError()
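
# Usage sketch (not part of the original module). It assumes that `ev_util`
# is an already-constructed MultiDataModelEvaluation, `metric` is a matching
# RegressionMetric/ClassificationMetric instance, and `create_model` is a
# user-supplied factory returning a named model; all three are placeholders.
def compute_combined_metric(ev_util: MultiDataModelEvaluation, metric: TMetric,
        create_model: Callable[[], TModel]) -> MetricComputationResult:
    computation = MetricComputationMultiData(
        ev_util=ev_util,
        use_cross_validation=True,     # collect the trained models of every fold
        metric=metric,
        use_combined_eval_stats=True,  # the only mode implemented above
    )
    result = computation.compute_metric_value(create_model)
    print(f"combined metric value: {result.metric_value}, models trained: {len(result.models)}")
    return result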