Coverage for src/sensai/xgboost.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import logging
2from typing import Optional
4import pandas as pd
5import xgboost
7from . import InputOutputData
8from .data import DataSplitter
9from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \
10 FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification, ActualFitParams
11from .util.pickle import setstate
# module-level logger, named after this module
log = logging.getLogger(name=__name__)
def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None,
        version: Optional[str] = None) -> bool:
    """
    Checks whether an XGBoost version is at least the given version.

    Components are compared in order (major, minor, patch); a required component given as None is
    not checked. Only the leading digits of each dot-separated component of the version string are
    considered. Non-numeric suffixes, as in pre-release/dev versions such as "2.1.0rc1", are therefore
    ignored, and "2.1.0rc1" counts as 2.1.0. Components missing from the version string (e.g. the
    patch component of "2.0") are treated as 0.

    :param major: the minimum major version
    :param minor: the minimum minor version; None to not check the minor version
    :param patch: the minimum patch version; None to not check the patch version
    :param version: the version string to check; if None, the installed version (``xgboost.__version__``) is used
    :return: True if the version is at least the given version, False otherwise
    """
    import re  # local import: re is not among this module's top-level imports

    if version is None:
        version = xgboost.__version__

    installed_components = []
    for part in version.split("."):
        # int("0rc1") would raise, so only the leading digits are parsed
        match = re.match(r"\d+", part)
        installed_components.append(int(match.group()) if match else 0)

    for i, required in enumerate((major, minor, patch)):
        if required is None:
            continue
        installed = installed_components[i] if i < len(installed_components) else 0
        if installed > required:
            return True
        if installed < required:
            return False
    return True
class XGBGradientBoostedVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    XGBoost's regression model using gradient boosted trees
    """

    def __init__(self, random_state=42,
            early_stopping_rounds: Optional[int] = None,
            early_stopping_data_splitter: Optional[DataSplitter] = None,
            **model_args):
        """
        :param random_state: the random seed, passed on to the XGBoost model
        :param early_stopping_rounds: if not None, early stopping is enabled. The value is passed on to
            the XGBoost model (see XGBoost's documentation for its exact semantics). A validation set is
            then split off the training data during fitting, using `early_stopping_data_splitter`, which
            must be provided in this case.
        :param early_stopping_data_splitter: the splitter with which the validation set is separated from
            the training data when early stopping is enabled; required if `early_stopping_rounds` is not None
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        super().__init__(xgboost.XGBRegressor, random_state=random_state, early_stopping_rounds=early_stopping_rounds,
            **model_args)
        self.is_early_stopping_enabled = early_stopping_rounds is not None
        self.early_stopping_data_splitter = early_stopping_data_splitter

    def __setstate__(self, state):
        # instances pickled before early stopping support was added lack these attributes
        setstate(XGBGradientBoostedVectorRegressionModel, self, state,
            new_default_properties=dict(
                is_early_stopping_enabled=False,
                early_stopping_data_splitter=None))

    def is_sample_weight_supported(self) -> bool:
        return True

    def _compute_actual_fit_params(self, inputs: pd.DataFrame, outputs: pd.DataFrame,
            weights: Optional[pd.Series] = None) -> ActualFitParams:
        """
        Computes the parameters with which the underlying XGBoost model(s) are fitted.

        If early stopping is enabled, a validation set is split off the given data with the configured
        splitter. The model is then fitted on the remaining (training) part only, and the validation set
        is passed on as `eval_set`.

        :raises ValueError: if early stopping is enabled but no early stopping data splitter is set
        """
        kwargs = {}
        if self.is_early_stopping_enabled:
            if self.early_stopping_data_splitter is None:
                # fail with a clear message instead of an AttributeError on None below
                raise ValueError("Early stopping is enabled (early_stopping_rounds is set), but no "
                    "early_stopping_data_splitter was provided for splitting off a validation set")
            data = InputOutputData(inputs, outputs, weights=weights)
            train_data, val_data = self.early_stopping_data_splitter.split(data)
            kwargs["eval_set"] = [(val_data.inputs, val_data.outputs)]
            # fit on the training part only; the validation part is used solely for early stopping
            inputs = train_data.inputs
            outputs = train_data.outputs
            weights = train_data.weights
            log.info(f"Early stopping enabled with validation set of size {len(val_data)}")
        params = super()._compute_actual_fit_params(inputs, outputs, weights=weights)
        params.kwargs.update(kwargs)
        return params
class XGBRandomForestVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Vector regression model backed by XGBoost's random forest regressor (xgboost.XGBRFRegressor)
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed, passed on to the XGBRFRegressor
        :param model_args: further XGBRFRegressor parameters;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        estimator_class = xgboost.XGBRFRegressor
        super().__init__(estimator_class, random_state=random_state, **model_args)

    def is_sample_weight_supported(self) -> bool:
        # the model reports that it can be fitted with sample weights
        return True
class XGBGradientBoostedVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model backed by XGBoost's gradient boosted trees classifier (xgboost.XGBClassifier)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, passed on to the XGBClassifier
        :param use_balanced_class_weights: passed on to the base class
        :param model_args: further XGBClassifier parameters;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # label encoding is switched on only for XGBoost >= 1.6
        # NOTE(review): presumably because newer XGBoost expects integer-encoded class labels — confirm
        super().__init__(xgboost.XGBClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=is_xgboost_version_at_least(1, 6),
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        # the model reports that it can be fitted with sample weights
        return True
class XGBRandomForestVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model backed by XGBoost's random forest classifier (xgboost.XGBRFClassifier)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, passed on to the XGBRFClassifier
        :param use_balanced_class_weights: passed on to the base class
        :param model_args: further XGBRFClassifier parameters;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # label encoding is switched on only for XGBoost >= 1.6
        # NOTE(review): presumably because newer XGBoost expects integer-encoded class labels — confirm
        super().__init__(xgboost.XGBRFClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=is_xgboost_version_at_least(1, 6),
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        # the model reports that it can be fitted with sample weights
        return True