Coverage for src/sensai/ensemble/ensemble_base.py: 37%
52 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from abc import ABC, abstractmethod
2from concurrent.futures.process import ProcessPoolExecutor
3from typing import Sequence, List
4from inspect import currentframe, getframeinfo
6import pandas as pd
8from ..vector_model import VectorModel
9from ..util.multiprocessing import VectorModelWithSeparateFeatureGeneration
10from ..util.pickle import PickleFailureDebugger
class EnsembleVectorModel(VectorModel, ABC):
    """
    Abstract base class for vector models which delegate to a collection of underlying models
    and combine the individual models' predictions via :meth:`aggregate_predictions`.
    """

    def __init__(self, models: Sequence[VectorModel], num_processes=1):
        """
        :param models: the models that make up the ensemble (copied into a new list)
        :param num_processes: the maximum number of worker processes used for fitting and prediction;
            if it is 1 (or the ensemble contains only a single model), the models are processed
            sequentially in the current process and no process pool is created
        """
        self.num_processes = num_processes
        self.models = list(models)
        super().__init__(check_input_columns=False)

    def _fit(self, x: pd.DataFrame, y: pd.DataFrame):
        """
        Fits all models of the ensemble on the given data.

        In the parallel case, each model is wrapped in a VectorModelWithSeparateFeatureGeneration;
        ``fit_start`` is called in the current process, the resulting object's ``execute`` method is
        submitted to the process pool, and the result is passed to ``fit_end``, whose return value
        replaces the respective entry in ``self.models``.

        :param x: the input data
        :param y: the target data
        """
        if self.num_processes == 1 or len(self.models) == 1:
            for model in self.models:
                model.fit(x, y)
            return

        fitters = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # The context manager guarantees that the pool's worker processes are shut down once all
        # results have been collected, and also if an exception is raised along the way.
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            fitted_model_futures = []
            for fitter in fitters:
                intermediate_step = fitter.fit_start(x, y)
                # the submitted object must be picklable; log diagnostic info (if enabled) should it not be
                frame_info = getframeinfo(currentframe())
                PickleFailureDebugger.log_failure_if_enabled(
                    intermediate_step,
                    context_info=f"Submitting {fitter} in {frame_info.filename}:{frame_info.lineno}")
                fitted_model_futures.append(executor.submit(intermediate_step.execute))
            for i, fitted_model_future in enumerate(fitted_model_futures):
                self.models[i] = fitters[i].fit_end(fitted_model_future.result())

    def compute_all_predictions(self, x: pd.DataFrame):
        """
        Computes the predictions of every model in the ensemble.

        :param x: the input data
        :return: a list containing one prediction result per model, in the order of ``self.models``
        """
        if self.num_processes == 1 or len(self.models) == 1:
            return [model.predict(x) for model in self.models]

        predictors = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # The context manager guarantees that the pool's worker processes are shut down once all
        # results have been collected, and also if an exception is raised along the way.
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            prediction_futures = []
            for predictor in predictors:
                predict_finaliser = predictor.predict_start(x)
                # the submitted object must be picklable; log diagnostic info (if enabled) should it not be
                frame_info = getframeinfo(currentframe())
                PickleFailureDebugger.log_failure_if_enabled(
                    predict_finaliser,
                    context_info=f"Submitting {predict_finaliser} in {frame_info.filename}:{frame_info.lineno}")
                prediction_futures.append(executor.submit(predict_finaliser.execute))
            return [prediction_future.result() for prediction_future in prediction_futures]

    def _predict(self, x):
        """
        Computes the ensemble's prediction by aggregating the predictions of all models.

        :param x: the input data
        :return: the result of :meth:`aggregate_predictions` applied to the individual predictions
        """
        predictions_data_frames = self.compute_all_predictions(x)
        return self.aggregate_predictions(predictions_data_frames)

    @abstractmethod
    def aggregate_predictions(self, predictions_data_frames: List[pd.DataFrame]) -> pd.DataFrame:
        """
        Combines the individual models' predictions into the prediction of the ensemble.

        :param predictions_data_frames: the predictions of the individual models, as returned by
            :meth:`compute_all_predictions`
        :return: the aggregated predictions
        """
        pass
class EnsembleRegressionVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base class for ensembles that declare themselves to be regression models.
    """

    def is_regression_model(self) -> bool:
        """
        :return: True, always
        """
        return True
class EnsembleClassificationVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base class for ensembles that declare themselves not to be regression models.
    """

    def is_regression_model(self) -> bool:
        """
        :return: False, always
        """
        return False