Coverage for src/sensai/ensemble/ensemble_base.py: 37%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from abc import ABC, abstractmethod 

2from concurrent.futures.process import ProcessPoolExecutor 

3from typing import Sequence, List 

4from inspect import currentframe, getframeinfo 

5 

6import pandas as pd 

7 

8from ..vector_model import VectorModel 

9from ..util.multiprocessing import VectorModelWithSeparateFeatureGeneration 

10from ..util.pickle import PickleFailureDebugger 

11 

12 

class EnsembleVectorModel(VectorModel, ABC):
    """
    Abstract base class for models which combine the predictions of several member models.

    Subclasses define how the members' predictions are combined by implementing
    :meth:`aggregate_predictions`. Fitting and prediction of the members can optionally be
    distributed across several worker processes (see ``num_processes``).
    """

    def __init__(self, models: Sequence[VectorModel], num_processes=1):
        """
        :param models: the member models of the ensemble. The sequence is copied into a list; when
            fitting is carried out in worker processes, the list entries are replaced by the fitted
            models returned from the workers.
        :param num_processes: the maximum number of worker processes to use for fitting and
            prediction (passed on as ``max_workers`` of a ``ProcessPoolExecutor``). If 1, or if the
            ensemble has at most one member, all members are processed sequentially in the
            current process.
        """
        self.num_processes = num_processes
        self.models = list(models)
        super().__init__(check_input_columns=False)

    def _runs_sequentially(self) -> bool:
        """
        :return: True if the members are to be processed one after another in the current process,
            i.e. if a single process was requested or there is at most one member (in which case a
            process pool could not provide any parallelism)
        """
        return self.num_processes == 1 or len(self.models) <= 1

    def _fit(self, x: pd.DataFrame, y: pd.DataFrame):
        """
        Fits all member models on the same data.

        :param x: the input data frame
        :param y: the target data frame
        """
        if self._runs_sequentially():
            for model in self.models:
                model.fit(x, y)
            return

        fitters = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # the frame information is identical for every submission, so it is computed only once
        frame_info = getframeinfo(currentframe())
        # using the pool as a context manager shuts it down once all results have been collected,
        # so that the worker processes do not outlive this call
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            fitted_model_futures = []
            for fitter in fitters:
                # fit_start runs in this process; only the returned step's `execute` is sent to a
                # worker process, so the step must be picklable
                intermediate_step = fitter.fit_start(x, y)
                PickleFailureDebugger.log_failure_if_enabled(intermediate_step,
                    context_info=f"Submitting {fitter} in {frame_info.filename}:{frame_info.lineno}")
                fitted_model_futures.append(executor.submit(intermediate_step.execute))
            # futures are consumed in submission order, so index i still refers to the i-th member
            for i, fitted_model_future in enumerate(fitted_model_futures):
                self.models[i] = fitters[i].fit_end(fitted_model_future.result())

    def compute_all_predictions(self, x: pd.DataFrame) -> List[pd.DataFrame]:
        """
        Computes the predictions of every member model for the given input.

        :param x: the input data frame
        :return: the members' predictions, in the same order as the members in ``self.models``
        """
        if self._runs_sequentially():
            return [model.predict(x) for model in self.models]

        predictors = [VectorModelWithSeparateFeatureGeneration(model) for model in self.models]
        # the frame information is identical for every submission, so it is computed only once
        frame_info = getframeinfo(currentframe())
        # using the pool as a context manager shuts it down once all results have been collected,
        # so that the worker processes do not outlive this call
        with ProcessPoolExecutor(max_workers=self.num_processes) as executor:
            prediction_futures = []
            for predictor in predictors:
                # predict_start runs in this process; only the returned finaliser's `execute` is
                # sent to a worker process, so the finaliser must be picklable
                predict_finaliser = predictor.predict_start(x)
                PickleFailureDebugger.log_failure_if_enabled(predict_finaliser,
                    context_info=f"Submitting {predict_finaliser} in {frame_info.filename}:{frame_info.lineno}")
                prediction_futures.append(executor.submit(predict_finaliser.execute))
            return [prediction_future.result() for prediction_future in prediction_futures]

    def _predict(self, x: pd.DataFrame) -> pd.DataFrame:
        """
        Predicts by combining the members' predictions via :meth:`aggregate_predictions`.

        :param x: the input data frame
        :return: the aggregated prediction
        """
        predictions_data_frames = self.compute_all_predictions(x)
        return self.aggregate_predictions(predictions_data_frames)

    @abstractmethod
    def aggregate_predictions(self, predictions_data_frames: List[pd.DataFrame]) -> pd.DataFrame:
        """
        Combines the members' predictions into the ensemble's prediction.

        :param predictions_data_frames: the members' predictions, in the order of ``self.models``
        :return: the combined prediction
        """
        pass

63 

64 

class EnsembleRegressionVectorModel(EnsembleVectorModel, ABC):
    """
    Ensemble base class for regression tasks.

    Concrete subclasses must still implement ``aggregate_predictions``.
    """

    def is_regression_model(self) -> bool:
        # every ensemble derived from this class is a regression model
        return True

68 

69 

class EnsembleClassificationVectorModel(EnsembleVectorModel, ABC):
    """
    Ensemble base class for classification tasks.

    Concrete subclasses must still implement ``aggregate_predictions``.
    """

    def is_regression_model(self) -> bool:
        # ensembles derived from this class are classifiers, never regressors
        return False