Coverage for src/sensai/torch/torch_eval_util.py: 0%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from typing import Union 

2 

3from . import TorchVectorRegressionModel 

4from ..evaluation import RegressionModelEvaluation 

5from ..evaluation.crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData 

6from ..evaluation.eval_util import EvaluationResultCollector 

7from ..evaluation.evaluator import VectorModelEvaluationData, VectorRegressionModelEvaluationData 

8 

9 

class TorchVectorRegressionModelEvaluationUtil(RegressionModelEvaluation):
    """
    Regression model evaluation which, in addition to the plots created by the
    base class, adds loss progression figures for models that are instances of
    TorchVectorRegressionModel.
    """

    def _create_plots(self,
            data: Union[VectorRegressionModelEvaluationData, VectorRegressionModelCrossValidationData],
            result_collector: EvaluationResultCollector,
            subtitle=None):
        """
        Creates the base class plots and then adds loss progression plots.

        For evaluation data, a single figure named "loss-progression" is considered
        for `data.model`. For cross-validation data whose `trained_models` is not None,
        one figure per trained model is considered, named "loss-progression-<i>" with
        i counting from 1.

        :param data: the evaluation or cross-validation data to create plots for
        :param result_collector: the collector to which figures are added
        :param subtitle: passed on unchanged to the base class implementation
        """
        super()._create_plots(data, result_collector, subtitle)
        add_plot = self._add_loss_progression_plot_if_torch_vector_regression_model

        # single-model case
        if isinstance(data, VectorModelEvaluationData):
            add_plot(data.model, "loss-progression", result_collector)
            return

        # anything that is not cross-validation data gets no additional plots
        if not isinstance(data, VectorModelCrossValidationData):
            return

        trained = data.trained_models
        if trained is None:
            return
        for idx, trained_model in enumerate(trained, start=1):
            add_plot(trained_model, f"loss-progression-{idx}", result_collector)

    @staticmethod
    def _add_loss_progression_plot_if_torch_vector_regression_model(model, plot_name, result_collector):
        """
        Adds the result of `model.model.trainingInfo.plot_all()` to the collector under
        the given name, provided that the model is a TorchVectorRegressionModel;
        does nothing for any other model.

        :param model: the model to check
        :param plot_name: the name under which the figure is added
        :param result_collector: the collector whose `add_figure` method is called
        """
        if not isinstance(model, TorchVectorRegressionModel):
            return
        result_collector.add_figure(
            plot_name,
            model.model.trainingInfo.plot_all(),
        )