Coverage for src/sensai/multi_model.py: 0%
17 statements
coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import functools
2from typing import Union
4import pandas as pd
6from . import VectorRegressionModel
7from .vector_model import RuleBasedVectorRegressionModel
class MultiVectorRegressionModel(RuleBasedVectorRegressionModel):
    """
    Combines several (previously trained) regression models into a single regression model that produces the combined output of the
    individual models (concatenating their outputs).

    The sub-models must predict pairwise-disjoint sets of variables; this is validated at construction time.
    """
    def __init__(self, *models: VectorRegressionModel):
        """
        :param models: the (already trained) models whose outputs are to be combined; their order determines the
            order of the predicted variables (and hence of the output columns) of the combined model
        :raises ValueError: if any predicted variable name occurs more than once across the given models
        """
        self.models = models
        predicted_variable_names_list = [m.get_predicted_variable_names() for m in models]
        # Flatten while preserving model order and each model's own column order.
        predicted_variable_names = [name for names in predicted_variable_names_list for name in names]
        # Outputs are disjoint iff no name appears twice in the flattened list. Note that comparing the flattened
        # length with the sum of the per-model lengths would always succeed and thus never detect overlaps.
        if len(set(predicted_variable_names)) != len(predicted_variable_names):
            raise ValueError(f"Models do not produce disjoint outputs: {predicted_variable_names_list}")
        super().__init__(predicted_variable_names)

    def _predict(self, x: pd.DataFrame) -> Union[pd.DataFrame, list]:
        """
        Applies every sub-model to the same input and concatenates the resulting frames column-wise.

        :param x: the input data passed unchanged to each sub-model's ``predict``
        :return: a data frame containing the columns of all sub-model predictions, in model order
        """
        dfs = [m.predict(x) for m in self.models]
        # pd.concat with axis=1 aligns rows on the index of the individual prediction frames.
        # NOTE(review): with zero models, pd.concat receives an empty list and raises ValueError.
        combined_df = pd.concat(dfs, axis=1)
        return combined_df