Coverage for src/sensai/torch/torch_eval_util.py: 0%
19 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1from typing import Union
3from . import TorchVectorRegressionModel
4from ..evaluation import RegressionModelEvaluation
5from ..evaluation.crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData
6from ..evaluation.eval_util import EvaluationResultCollector
7from ..evaluation.evaluator import VectorModelEvaluationData, VectorRegressionModelEvaluationData
class TorchVectorRegressionModelEvaluationUtil(RegressionModelEvaluation):
    """
    Regression model evaluation utility which, in addition to the plots created by the base class,
    registers a "loss-progression" figure for every TorchVectorRegressionModel it encounters in the
    evaluation or cross-validation data.
    """

    def _create_plots(self,
            data: Union[VectorRegressionModelEvaluationData, VectorRegressionModelCrossValidationData],
            result_collector: EvaluationResultCollector,
            subtitle=None):
        """
        Creates the base class plots and then adds the training figures of torch vector regression models.

        For plain evaluation data, the evaluated model's figure is added under the name "loss-progression".
        For cross-validation data, the figure of each trained model is added under "loss-progression-<k>",
        where k counts the trained models starting at 1; nothing is added if no trained models were retained.

        :param data: the evaluation data (single split or cross-validation)
        :param result_collector: the collector to which figures are added
        :param subtitle: the subtitle, which is passed on unchanged to the base class implementation
        """
        super()._create_plots(data, result_collector, subtitle)
        add_plot = self._add_loss_progression_plot_if_torch_vector_regression_model

        # single evaluation: exactly one model, one figure
        if isinstance(data, VectorModelEvaluationData):
            add_plot(data.model, "loss-progression", result_collector)
            return

        # any other data type apart from cross-validation data yields no additional plots
        if not isinstance(data, VectorModelCrossValidationData):
            return

        # the trained models are optional in cross-validation data
        if data.trained_models is None:
            return
        for fold_number, trained_model in enumerate(data.trained_models, start=1):
            add_plot(trained_model, f"loss-progression-{fold_number}", result_collector)

    @staticmethod
    def _add_loss_progression_plot_if_torch_vector_regression_model(model, plot_name, result_collector):
        """
        Adds the figure produced by the model's training info to the result collector, provided that the model
        is a TorchVectorRegressionModel; models of any other type are ignored.

        :param model: the model whose training figure shall be added
        :param plot_name: the name under which the figure is registered with the collector
        :param result_collector: the collector to which the figure is added
        """
        if not isinstance(model, TorchVectorRegressionModel):
            return
        # NOTE(review): assumes the model has been fitted, i.e. model.model and its trainingInfo are set — confirm
        training_info = model.model.trainingInfo
        result_collector.add_figure(plot_name, training_info.plot_all())