Coverage for src/sensai/ensemble/ensemble_base.py: 36%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1from abc import ABC, abstractmethod 

2from concurrent.futures.process import ProcessPoolExecutor 

3from typing import Sequence, List, Optional 

4from inspect import currentframe, getframeinfo 

5 

6import pandas as pd 

7 

8from ..vector_model import VectorModel 

9from ..util.multiprocessing import VectorModelWithSeparateFeatureGeneration 

10from ..util.pickle import PickleFailureDebugger 

11 

12 

class EnsembleVectorModel(VectorModel, ABC):
    """
    Abstract base class for vector models which combine the predictions of several member models.

    All member models are fitted on the same data. Subclasses define how the members' predictions
    are combined by implementing :meth:`aggregate_predictions`. Fitting and prediction can optionally
    be distributed across several worker processes.
    """

    def __init__(self, models: Sequence[VectorModel], num_processes=1):
        """
        :param models: the member models of the ensemble; each member is fitted on the same data.
            The sequence is copied into a new list (``self.models``), whose entries are replaced by the
            fitted models when fitting is done in worker processes.
        :param num_processes: the number of worker processes to use for fitting and prediction.
            If 1 (or if the ensemble has only a single member), all work is done sequentially in the
            current process; otherwise the value is passed as ``max_workers`` to a ``ProcessPoolExecutor``.
        """
        self.num_processes = num_processes
        self.models = list(models)
        # input columns are not checked at the ensemble level
        super().__init__(check_input_columns=False)

    def _use_sequential_execution(self) -> bool:
        """
        :return: True if fitting/prediction shall be done sequentially in the current process
            rather than in a process pool
        """
        return self.num_processes == 1 or len(self.models) == 1

    def _fit(self, x: pd.DataFrame, y: pd.DataFrame, weights: Optional[pd.Series] = None):
        """
        Fits all member models on the given data.

        Sample weights are not forwarded to the member models. The given weights are passed to
        ``_warn_sample_weights_unsupported``, which presumably warns if weights were provided.

        :param x: the input data
        :param y: the target data
        :param weights: optional sample weights (not supported, see above)
        """
        self._warn_sample_weights_unsupported(False, weights)

        if self._use_sequential_execution():
            for model in self.models:
                model.fit(x, y)
            return

        fitters = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # The context manager guarantees that the pool's worker processes are shut down, also if a
        # submission or one of the remote fits raises. The original code never closed the pool.
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            fitted_model_futures = []
            for fitter in fitters:
                # fit_start runs in the current process; the returned step is pickled and its
                # execute() is run in a worker process
                intermediate_step = fitter.fit_start(x, y)
                frame_info = getframeinfo(currentframe())
                # presumably logs diagnostic information on pickling problems (if enabled)
                PickleFailureDebugger.log_failure_if_enabled(
                    intermediate_step,
                    context_info=f"Submitting {fitter} in {frame_info.filename}:{frame_info.lineno}")
                fitted_model_futures.append(executor.submit(intermediate_step.execute))
            # collect results in submission order, replacing each member by its fitted counterpart
            for i, (fitter, fitted_model_future) in enumerate(zip(fitters, fitted_model_futures)):
                self.models[i] = fitter.fit_end(fitted_model_future.result())

    def compute_all_predictions(self, x: pd.DataFrame) -> List[pd.DataFrame]:
        """
        Computes the predictions of all member models.

        :param x: the input data
        :return: the members' predictions, in the order of ``self.models``
        """
        if self._use_sequential_execution():
            return [model.predict(x) for model in self.models]

        predictors = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # context manager ensures the worker processes are shut down (previously they were leaked)
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            prediction_futures = []
            for predictor in predictors:
                # predict_start runs in the current process; the returned finaliser is pickled and
                # its execute() is run in a worker process
                predict_finaliser = predictor.predict_start(x)
                frame_info = getframeinfo(currentframe())
                PickleFailureDebugger.log_failure_if_enabled(
                    predict_finaliser,
                    context_info=f"Submitting {predict_finaliser} in {frame_info.filename}:{frame_info.lineno}")
                prediction_futures.append(executor.submit(predict_finaliser.execute))
            return [prediction_future.result() for prediction_future in prediction_futures]

    def _predict(self, x):
        """
        Computes all member predictions and combines them via :meth:`aggregate_predictions`.

        :param x: the input data
        :return: the aggregated prediction
        """
        predictions_data_frames = self.compute_all_predictions(x)
        return self.aggregate_predictions(predictions_data_frames)

    @abstractmethod
    def aggregate_predictions(self, predictions_data_frames: List[pd.DataFrame]) -> pd.DataFrame:
        """
        Combines the predictions of the member models into a single prediction.

        :param predictions_data_frames: the predictions of the member models, in the order of ``self.models``
        :return: the combined prediction
        """
        pass

65 

66 

class EnsembleRegressionVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base for ensembles whose member predictions are combined into a regression output.
    """

    def is_regression_model(self) -> bool:
        """
        :return: always True, marking this ensemble as a regression model
        """
        regression = True
        return regression

70 

71 

class EnsembleClassificationVectorModel(EnsembleVectorModel, ABC):
    """
    Abstract base for ensembles whose member predictions are combined into a classification output.
    """

    def is_regression_model(self) -> bool:
        """
        :return: always False, marking this ensemble as a classification (non-regression) model
        """
        regression = False
        return regression