Coverage for src/sensai/evaluation/crossval.py: 40%
156 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import copy
2import functools
3import logging
4from abc import ABC, abstractmethod
5from typing import Tuple, Any, Generator, Generic, TypeVar, List, Union, Sequence, Optional
7import numpy as np
9from .eval_stats.eval_stats_base import PredictionEvalStats, EvalStatsCollection
10from .eval_stats.eval_stats_classification import ClassificationEvalStats, ClassificationEvalStatsCollection
11from .eval_stats.eval_stats_regression import RegressionEvalStats, RegressionEvalStatsCollection
12from .evaluator import VectorRegressionModelEvaluationData, VectorClassificationModelEvaluationData, \
13 VectorModelEvaluationData, VectorClassificationModelEvaluator, VectorRegressionModelEvaluator, \
14 MetricsDictProvider, VectorModelEvaluator, ClassificationEvaluatorParams, \
15 RegressionEvaluatorParams, MetricsDictProviderFromFunction
16from ..data import InputOutputData, DataSplitterFractional
17from ..tracking.tracking_base import TrackingContext
18from ..util.typing import PandasNamedTuple
19from ..vector_model import VectorClassificationModel, VectorRegressionModel, VectorModel
21log = logging.getLogger(__name__)
23TModel = TypeVar("TModel", bound=VectorModel)
24TEvalStats = TypeVar("TEvalStats", bound=PredictionEvalStats)
25TEvalStatsCollection = TypeVar("TEvalStatsCollection", bound=EvalStatsCollection)
26TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData)
class VectorModelCrossValidationData(ABC, Generic[TModel, TEvalData, TEvalStats, TEvalStatsCollection]):
    """
    Container for the results of a cross-validation run, i.e. the evaluation data of every fold and, optionally,
    the models that were trained and the test indices that were used in the respective folds
    """
    def __init__(self, trained_models: Optional[List[TModel]], eval_data_list: List[TEvalData], predicted_var_names: List[str],
            test_indices_list=None):
        """
        :param trained_models: the models trained in the individual folds; may be None if the models were not retained
        :param eval_data_list: the evaluation data, one element per fold
        :param predicted_var_names: the names of the predicted variables
        :param test_indices_list: optional list holding, for every fold, the indices of the test data points
        """
        self.trained_models = trained_models
        self.eval_data_list = eval_data_list
        self.test_indices_list = test_indices_list
        self.predicted_var_names = predicted_var_names

    @property
    def model_name(self):
        """
        :return: the model name as reported by the evaluation data of the first fold
        """
        first_fold_data = self.eval_data_list[0]
        return first_fold_data.model_name

    @abstractmethod
    def _create_eval_stats_collection(self, l: List[TEvalStats]) -> TEvalStatsCollection:
        """
        :param l: the list of per-fold evaluation statistics
        :return: the collection object wrapping the given statistics
        """

    def get_eval_stats_collection(self, predicted_var_name=None) -> TEvalStatsCollection:
        """
        :param predicted_var_name: the predicted variable for which to collect the per-fold evaluation statistics;
            may be None only if exactly one variable was predicted
        :return: the collection of evaluation statistics (one element per fold) for the variable
        """
        if predicted_var_name is None:
            # the variable may only be left unspecified if there is no ambiguity
            if len(self.predicted_var_names) != 1:
                raise Exception(f"Must provide name of predicted variable name, as multiple variables were predicted: "
                                f"{self.predicted_var_names}")
            predicted_var_name = self.predicted_var_names[0]
        per_fold_stats = []
        for fold_data in self.eval_data_list:
            per_fold_stats.append(fold_data.get_eval_stats(predicted_var_name))
        return self._create_eval_stats_collection(per_fold_stats)

    def iter_input_output_ground_truth_tuples(self, predicted_var_name=None) -> Generator[Tuple[PandasNamedTuple, Any, Any], None, None]:
        """
        Iterates over all evaluated data points across all folds.

        :param predicted_var_name: the name of the predicted variable, which is passed on to the evaluation data of each fold
        :return: a generator of tuples (input row as named tuple, predicted value, true value)
        """
        for fold_data in self.eval_data_list:
            stats = fold_data.get_eval_stats(predicted_var_name)
            yield from (
                (row, stats.y_predicted[pos], stats.y_true[pos])
                for pos, row in enumerate(fold_data.input_data.itertuples())
            )

    def track_metrics(self, tracking_context: TrackingContext):
        """
        Passes the aggregated metrics of every predicted variable to the given tracking context.

        :param tracking_context: the context with which to track the metrics
        """
        has_multiple_vars = len(self.predicted_var_names) > 1
        for var_name in self.predicted_var_names:
            aggregated_metrics = self.get_eval_stats_collection(predicted_var_name=var_name).agg_metrics_dict()
            # the variable name is reported only if it is needed to tell several variables apart
            reported_var_name = var_name if has_multiple_vars else None
            tracking_context.track_metrics(aggregated_metrics, predicted_var_name=reported_var_name)
69TCrossValData = TypeVar("TCrossValData", bound=VectorModelCrossValidationData)
class CrossValidationSplitter(ABC):
    """
    Interface for strategies which generate the data splits (folds) that are used in cross-validation
    """
    @abstractmethod
    def create_folds(self, data: InputOutputData, num_folds: int) -> List[Tuple[Sequence[int], Sequence[int]]]:
        """
        :param data: the data set which is to be divided into folds
        :param num_folds: the number of folds (splits) to generate
        :return: a list of num_folds tuples (t, e), where t is the sequence of indices of the data points to use for training
            and e is the sequence of indices of the data points to use for evaluation
        """
class CrossValidationSplitterDefault(CrossValidationSplitter):
    """
    Default splitter, which cuts the (optionally shuffled) sequence of data point indices into consecutive test folds
    of equal size.

    Every fold uses ``len(data) // num_splits`` data points for testing and all other data points for training.
    If the number of data points is not divisible by the number of splits, the trailing remainder of the index
    sequence is never part of a test fold (it is part of the training indices of every fold).
    """
    def __init__(self, shuffle=True, random_seed=42):
        """
        :param shuffle: whether to randomly permute the data point indices before creating the folds
        :param random_seed: the seed of the random state that generates the permutation (relevant only if shuffle is True)
        """
        self.shuffle = shuffle
        self.randomSeed = random_seed

    def create_folds(self, data: InputOutputData, num_splits: int) -> List[Tuple[Sequence[int], Sequence[int]]]:
        """
        :param data: the data from which to obtain the folds (only its length is used)
        :param num_splits: the number of splits/folds
        :return: a list containing num_splits tuples (t, e) where t and e are integer arrays of data point indices to use
            for training and evaluation respectively
        """
        num_data_points = len(data)
        num_test_points = num_data_points // num_splits
        if self.shuffle:
            indices = np.random.RandomState(self.randomSeed).permutation(num_data_points)
        else:
            # An integer array is required here: with a plain list, an empty slice (first fold, or last fold if the
            # data is evenly divisible) would be converted to an empty float64 array by np.concatenate, turning all
            # training indices into floats, which are not usable as positional indices.
            indices = np.arange(num_data_points)
        result = []
        for i in range(num_splits):
            test_start_idx = i * num_test_points
            test_end_idx = test_start_idx + num_test_points
            test_indices = indices[test_start_idx:test_end_idx]
            # training data is everything before and after the test window
            train_indices = np.concatenate((indices[:test_start_idx], indices[test_end_idx:]))
            result.append((train_indices, test_indices))
        return result
class CrossValidationSplitterNested(CrossValidationSplitter):
    """
    A data splitter for nested cross-validation (which is useful, in particular, for time series prediction problems):
    the folds are generated by repeatedly applying an unshuffled fractional split, where each subsequent split is
    applied to the first of the two data sets produced by the preceding split.
    """
    def __init__(self, test_fraction: float):
        """
        :param test_fraction: the fraction used to configure the fractional split; the splitter is created with
            the complementary fraction ``1 - test_fraction``
        """
        self.test_fraction = test_fraction

    def create_folds(self, data: InputOutputData, num_folds: int) -> List[Tuple[Sequence[int], Sequence[int]]]:
        """
        :param data: the data from which to obtain the folds
        :param num_folds: the number of folds, i.e. the number of successive splits to apply
        :return: a list containing, for every fold, the index information returned by the fractional splitter
            (presumably a pair of training and evaluation indices - to be confirmed against DataSplitterFractional)
        """
        splitter = DataSplitterFractional(1-self.test_fraction, shuffle=False)
        folds = []
        remaining_data = data
        for _ in range(num_folds):
            fold_indices, (first_part, _second_part) = splitter.split_with_indices(remaining_data)
            folds.append(fold_indices)
            # the next (nested) fold is generated from the first part of the current split only
            remaining_data = first_part
        return folds
class VectorModelCrossValidatorParams:
    """
    Parameters which configure a cross-validation (number of folds, splitting mechanism, evaluator parameters
    and handling of the trained models)
    """
    def __init__(self,
            folds: int = 5,
            splitter: CrossValidationSplitter = None,
            return_trained_models=False,
            evaluator_params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None,
            default_splitter_random_seed=42,
            default_splitter_shuffle=True):
        """
        :param folds: the number of folds
        :param splitter: the splitter with which to generate the folds; if None, a default splitter is created
            (using the random seed and shuffling parameters below)
        :param return_trained_models: whether to create a copy of the model for each fold and return each of the models
            (requires that models can be deep-copied); if False, the model that is passed for evaluation is fitted
            several times
        :param evaluator_params: the model evaluator parameters
        :param default_splitter_random_seed: [if splitter is None] the random seed to use for splits
        :param default_splitter_shuffle: [if splitter is None] whether to shuffle the data (using the random seed)
            before creating the folds
        """
        self.folds = folds
        self.returnTrainedModels = return_trained_models
        self.evaluatorParams = evaluator_params
        # fall back to the default splitting mechanism only if no splitter was provided
        self.splitter = splitter if splitter is not None else \
            CrossValidationSplitterDefault(shuffle=default_splitter_shuffle, random_seed=default_splitter_random_seed)
class VectorModelCrossValidator(MetricsDictProvider, Generic[TCrossValData], ABC):
    """
    Base class for cross-validators, which fit and evaluate a model once for every fold of a data set.
    The folds are generated upon construction; one model evaluator is created per fold.
    """
    def __init__(self, data: InputOutputData, params: VectorModelCrossValidatorParams):
        """
        :param data: the data set
        :param params: parameters
        """
        self.params = params
        self.modelEvaluators: List[VectorModelEvaluator] = []
        for trainIndices, testIndices in self.params.splitter.create_folds(data, self.params.folds):
            self.modelEvaluators.append(self._create_model_evaluator(data.filter_indices(trainIndices), data.filter_indices(testIndices)))

    @staticmethod
    def for_model(model: VectorModel, data: InputOutputData, params: VectorModelCrossValidatorParams) \
            -> Union["VectorClassificationModelCrossValidator", "VectorRegressionModelCrossValidator"]:
        """
        Creates the cross-validator that matches the type of the given model.

        :param model: the model, which determines whether a regression or a classification cross-validator is created
        :param data: the data set
        :param params: parameters
        :return: the cross-validator instance
        """
        if model.is_regression_model():
            return VectorRegressionModelCrossValidator(data, params)
        else:
            return VectorClassificationModelCrossValidator(data, params)

    @abstractmethod
    def _create_model_evaluator(self, training_data: InputOutputData, test_data: InputOutputData) -> VectorModelEvaluator:
        """
        :param training_data: the training data of a fold
        :param test_data: the test data of a fold
        :return: the evaluator with which to fit and evaluate models for the fold
        """
        pass

    @abstractmethod
    def _create_result_data(self, trained_models, eval_data_list, test_indices_list, predicted_var_names) -> TCrossValData:
        """
        :param trained_models: the list of models trained in the folds or None if models are not returned
        :param eval_data_list: the list of evaluation data objects (one per fold)
        :param test_indices_list: the list of test indices (one element per fold)
        :param predicted_var_names: the names of the predicted variables
        :return: the cross-validation result object
        """
        pass

    def eval_model(self, model: VectorModel, track: bool = True):
        """
        :param model: the model to evaluate
        :param track: whether tracking shall be enabled for the case where a tracked experiment is set on this object
        :return: cross-validation results
        """
        trained_models = [] if self.params.returnTrainedModels else None
        eval_data_list = []
        test_indices_list = []
        predicted_var_names = None
        num_folds = len(self.modelEvaluators)
        with self.begin_optional_tracking_context_for_model(model, track=track) as tracking_context:
            for i, evaluator in enumerate(self.modelEvaluators, start=1):
                evaluator: VectorModelEvaluator
                log.info(f"Training and evaluating model with fold {i}/{num_folds} ...")
                # a separate copy is fitted per fold only if the trained models are to be returned;
                # otherwise the given model instance is re-fitted in every fold
                model_to_fit: VectorModel = copy.deepcopy(model) if self.params.returnTrainedModels else model
                evaluator.fit_model(model_to_fit)
                eval_data = evaluator.eval_model(model_to_fit)
                if predicted_var_names is None:
                    # the predicted variables are taken from the first fold's evaluation data
                    predicted_var_names = eval_data.predicted_var_names
                if self.params.returnTrainedModels:
                    trained_models.append(model_to_fit)
                for predictedVarName in predicted_var_names:
                    log.info(f"Evaluation result for {predictedVarName}, fold {i}/{num_folds}: "
                             f"{eval_data.get_eval_stats(predicted_var_name=predictedVarName)}")
                eval_data_list.append(eval_data)
                test_indices_list.append(evaluator.test_data.outputs.index)
            crossval_data = self._create_result_data(trained_models, eval_data_list, test_indices_list, predicted_var_names)
            if tracking_context.is_enabled():
                crossval_data.track_metrics(tracking_context)
        return crossval_data

    def _compute_metrics(self, model: VectorModel, **kwargs):
        """
        :param model: the model to cross-validate
        :param kwargs: ignored
        :return: the aggregated metrics dictionary (applicable only if the model predicts a single variable)
        """
        return self._compute_metrics_for_var_name(model, None)

    def _compute_metrics_for_var_name(self, model, predicted_var_name: Optional[str]):
        """
        :param model: the model to cross-validate
        :param predicted_var_name: the predicted variable for which to obtain metrics; may be None only if the model
            predicts a single variable
        :return: the metrics dictionary aggregated across all folds
        """
        data = self.eval_model(model)
        return data.get_eval_stats_collection(predicted_var_name=predicted_var_name).agg_metrics_dict()

    def create_metrics_dict_provider(self, predicted_var_name: Optional[str]) -> MetricsDictProvider:
        """
        Creates a metrics dictionary provider, e.g. for use in hyperparameter optimisation

        :param predicted_var_name: the name of the predicted variable for which to obtain evaluation metrics; may be None only
            if the model outputs but a single predicted variable
        :return: a metrics dictionary provider instance for the given variable
        """
        # the keyword must match the parameter name of _compute_metrics_for_var_name, as the bound function
        # would otherwise raise a TypeError as soon as it is called
        return MetricsDictProviderFromFunction(functools.partial(self._compute_metrics_for_var_name,
            predicted_var_name=predicted_var_name))
class VectorRegressionModelCrossValidationData(
        VectorModelCrossValidationData[
            VectorRegressionModel,
            VectorRegressionModelEvaluationData,
            RegressionEvalStats,
            RegressionEvalStatsCollection]):
    """
    Cross-validation results for regression models
    """
    def _create_eval_stats_collection(self, l: List[RegressionEvalStats]) -> RegressionEvalStatsCollection:
        """
        :param l: the per-fold regression evaluation statistics
        :return: the collection wrapping the given statistics
        """
        stats_collection = RegressionEvalStatsCollection(l)
        return stats_collection
class VectorRegressionModelCrossValidator(VectorModelCrossValidator[VectorRegressionModelCrossValidationData]):
    """
    Cross-validator for regression models
    """
    def _create_model_evaluator(self, training_data: InputOutputData, test_data: InputOutputData) -> VectorRegressionModelEvaluator:
        """
        :param training_data: the training data of a fold
        :param test_data: the test data of a fold
        :return: a regression model evaluator for the fold, configured with this cross-validator's evaluator parameters
        """
        return VectorRegressionModelEvaluator(
            training_data,
            test_data=test_data,
            params=RegressionEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams))

    def _create_result_data(self, trained_models, eval_data_list, test_indices_list, predicted_var_names) \
            -> VectorRegressionModelCrossValidationData:
        """
        :param trained_models: the list of models trained in the folds or None
        :param eval_data_list: the list of evaluation data objects (one per fold)
        :param test_indices_list: the list of test indices (one element per fold)
        :param predicted_var_names: the names of the predicted variables
        :return: the regression cross-validation result object
        """
        return VectorRegressionModelCrossValidationData(
            trained_models=trained_models,
            eval_data_list=eval_data_list,
            predicted_var_names=predicted_var_names,
            test_indices_list=test_indices_list)
class VectorClassificationModelCrossValidationData(
        VectorModelCrossValidationData[
            VectorClassificationModel,
            VectorClassificationModelEvaluationData,
            ClassificationEvalStats,
            ClassificationEvalStatsCollection]):
    """
    Cross-validation results for classification models
    """
    def _create_eval_stats_collection(self, l: List[ClassificationEvalStats]) -> ClassificationEvalStatsCollection:
        """
        :param l: the per-fold classification evaluation statistics
        :return: the collection wrapping the given statistics
        """
        stats_collection = ClassificationEvalStatsCollection(l)
        return stats_collection
class VectorClassificationModelCrossValidator(VectorModelCrossValidator[VectorClassificationModelCrossValidationData]):
    """
    Cross-validator for classification models
    """
    def _create_model_evaluator(self, training_data: InputOutputData, test_data: InputOutputData):
        """
        :param training_data: the training data of a fold
        :param test_data: the test data of a fold
        :return: a classification model evaluator for the fold, configured with this cross-validator's evaluator parameters
        """
        return VectorClassificationModelEvaluator(
            training_data,
            test_data=test_data,
            params=ClassificationEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams))

    def _create_result_data(self, trained_models, eval_data_list, test_indices_list, predicted_var_names) \
            -> VectorClassificationModelCrossValidationData:
        """
        :param trained_models: the list of models trained in the folds or None
        :param eval_data_list: the list of evaluation data objects (one per fold)
        :param test_indices_list: the list of test indices (one element per fold)
        :param predicted_var_names: the names of the predicted variables
        :return: the classification cross-validation result object
        """
        return VectorClassificationModelCrossValidationData(
            trained_models=trained_models,
            eval_data_list=eval_data_list,
            predicted_var_names=predicted_var_names,
            test_indices_list=test_indices_list)