Source code for sensai.torch.torch_eval_util

from typing import Union

from . import TorchVectorRegressionModel
from ..evaluation import RegressionModelEvaluation
from ..evaluation.crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData
from ..evaluation.eval_util import EvaluationResultCollector
from ..evaluation.evaluator import VectorModelEvaluationData, VectorRegressionModelEvaluationData


[docs]class TorchVectorRegressionModelEvaluationUtil(RegressionModelEvaluation): def _create_plots(self, data: Union[VectorRegressionModelEvaluationData, VectorRegressionModelCrossValidationData], result_collector: EvaluationResultCollector, subtitle=None): super()._create_plots(data, result_collector, subtitle) if isinstance(data, VectorModelEvaluationData): self._add_loss_progression_plot_if_torch_vector_regression_model(data.model, "loss-progression", result_collector) elif isinstance(data, VectorModelCrossValidationData): if data.trained_models is not None: for i, model in enumerate(data.trained_models, start=1): self._add_loss_progression_plot_if_torch_vector_regression_model(model, f"loss-progression-{i}", result_collector) @staticmethod def _add_loss_progression_plot_if_torch_vector_regression_model(model, plot_name, result_collector): if isinstance(model, TorchVectorRegressionModel): result_collector.add_figure(plot_name, model.model.trainingInfo.plot_all())