Coverage for src/sensai/ensemble/ensemble_base.py: 36%
53 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1from abc import ABC, abstractmethod
2from concurrent.futures.process import ProcessPoolExecutor
3from typing import Sequence, List, Optional
4from inspect import currentframe, getframeinfo
6import pandas as pd
8from ..vector_model import VectorModel
9from ..util.multiprocessing import VectorModelWithSeparateFeatureGeneration
10from ..util.pickle import PickleFailureDebugger
class EnsembleVectorModel(VectorModel, ABC):
    """
    Abstract base for vector models that combine several member models.

    All member models are fitted on the same inputs, and their individual predictions are combined
    by :meth:`aggregate_predictions`, which subclasses must implement.
    If more than one process is requested and there is more than one member model, fitting and
    prediction are distributed over a :class:`ProcessPoolExecutor`; otherwise, all work is done
    sequentially in the current process.
    """

    def __init__(self, models: Sequence[VectorModel], num_processes: int = 1):
        """
        :param models: the member models whose predictions are to be combined; the sequence is copied
            into a new list, so later changes to the passed sequence do not affect the ensemble
        :param num_processes: the maximum number of worker processes to use for fitting and prediction.
            If 1, or if there is only a single member model, no worker processes are used.
        """
        self.num_processes = num_processes
        self.models = list(models)
        # input column checking is disabled for the ensemble itself
        super().__init__(check_input_columns=False)

    def _runs_sequentially(self) -> bool:
        """
        :return: True if fitting/prediction shall be done in the current process (no process pool)
        """
        return self.num_processes == 1 or len(self.models) == 1

    def _fit(self, x: pd.DataFrame, y: pd.DataFrame, weights: Optional[pd.Series] = None):
        """
        Fits all member models on the same data.

        In the parallel case, each member model is wrapped in a VectorModelWithSeparateFeatureGeneration.
        Its ``fit_start`` is called in the current process, the returned step's ``execute`` is run in a
        worker process, and ``fit_end`` is applied to the result. The outcome of ``fit_end`` replaces the
        corresponding entry in ``self.models``.

        :param x: the input data frame
        :param y: the output data frame
        :param weights: optional sample weights; these are not passed on to the member models
        """
        self._warn_sample_weights_unsupported(False, weights)

        if self._runs_sequentially():
            for model in self.models:
                model.fit(x, y)
            return

        fitters = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        frame_info = getframeinfo(currentframe())
        # The context manager guarantees that the pool's worker processes are shut down once all
        # results have been collected (or an exception is raised), instead of leaking one pool per call.
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            futures = []
            for fitter in fitters:
                intermediate_step = fitter.fit_start(x, y)
                PickleFailureDebugger.log_failure_if_enabled(intermediate_step,
                    context_info=f"Submitting {fitter} in {frame_info.filename}:{frame_info.lineno}")
                futures.append(executor.submit(intermediate_step.execute))
            for i, (fitter, future) in enumerate(zip(fitters, futures)):
                self.models[i] = fitter.fit_end(future.result())

    def compute_all_predictions(self, x: pd.DataFrame) -> List[pd.DataFrame]:
        """
        Computes the predictions of all member models for the given input.

        :param x: the input data frame
        :return: one prediction per member model, in the same order as ``self.models``
        """
        if self._runs_sequentially():
            return [model.predict(x) for model in self.models]

        predictors = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        frame_info = getframeinfo(currentframe())
        # The context manager guarantees that the pool's worker processes are shut down once all
        # results have been collected (or an exception is raised), instead of leaking one pool per call.
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            futures = []
            for predictor in predictors:
                predict_finaliser = predictor.predict_start(x)
                PickleFailureDebugger.log_failure_if_enabled(predict_finaliser,
                    context_info=f"Submitting {predict_finaliser} in {frame_info.filename}:{frame_info.lineno}")
                futures.append(executor.submit(predict_finaliser.execute))
            return [future.result() for future in futures]

    def _predict(self, x: pd.DataFrame) -> pd.DataFrame:
        """
        Predicts by computing the predictions of all member models and aggregating them.

        :param x: the input data frame
        :return: the aggregated prediction as returned by :meth:`aggregate_predictions`
        """
        return self.aggregate_predictions(self.compute_all_predictions(x))

    @abstractmethod
    def aggregate_predictions(self, predictions_data_frames: List[pd.DataFrame]) -> pd.DataFrame:
        """
        Combines the member models' predictions into the ensemble's prediction.

        :param predictions_data_frames: the predictions of the member models, in the order of ``self.models``
        :return: the aggregated prediction
        """
class EnsembleRegressionVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base for ensembles that are regression models.

    Subclasses still have to implement ``aggregate_predictions``.
    """

    def is_regression_model(self) -> bool:
        """
        :return: always True, marking this ensemble as a regression model
        """
        regression = True
        return regression
class EnsembleClassificationVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base for ensembles that are classification models.

    Subclasses still have to implement ``aggregate_predictions``.
    """

    def is_regression_model(self) -> bool:
        """
        :return: always False, marking this ensemble as a classification (non-regression) model
        """
        regression = False
        return regression