Coverage for src/sensai/multi_model.py: 0%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import functools 

2from typing import Union 

3 

4import pandas as pd 

5 

6from . import VectorRegressionModel 

7from .vector_model import RuleBasedVectorRegressionModel 

8 

9 

class MultiVectorRegressionModel(RuleBasedVectorRegressionModel):
    """
    Combines several (previously trained) regression models into a single regression model that produces the combined output of the
    individual models (concatenating their outputs column-wise).

    The predicted variable names of the individual models must be pairwise disjoint, because the combined model's
    predicted variables are the concatenation of all individual models' predicted variables.
    """
    def __init__(self, *models: VectorRegressionModel):
        """
        :param models: the models to combine (expected to be trained already); their predicted variable names
            must not overlap
        :raises ValueError: if the same predicted variable name is produced by more than one model (or occurs
            more than once overall)
        """
        self.models = models
        predicted_variable_names_list = [m.get_predicted_variable_names() for m in models]
        # Flatten the per-model name lists, preserving model order. A comprehension avoids the quadratic
        # repeated list concatenation of a reduce-based flatten.
        predicted_variable_names = [name for names in predicted_variable_names_list for name in names]
        # The outputs are disjoint iff no name occurs more than once in the flattened list. Comparing against
        # the total length alone is always equal, so the distinct names must be counted.
        if len(set(predicted_variable_names)) != len(predicted_variable_names):
            raise ValueError(f"Models do not produce disjoint outputs: {predicted_variable_names_list}")
        super().__init__(predicted_variable_names)

    def _predict(self, x: pd.DataFrame) -> Union[pd.DataFrame, list]:
        """
        Applies each combined model to the input and concatenates the resulting predictions column-wise.

        :param x: the input data frame, passed unchanged to each model's ``predict``
        :return: the column-wise concatenation of the individual models' predictions, in model order
            (assumes each model's ``predict`` returns a data frame -- TODO confirm for all model types)
        """
        dfs = [m.predict(x) for m in self.models]
        # axis=1 joins column-wise. pandas aligns rows on the index, so the individual prediction frames are
        # presumably expected to share x's index -- verify if models may reindex their output.
        return pd.concat(dfs, axis=1)