Coverage for src/sensai/xgboost.py: 0%
27 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from typing import Optional
3import xgboost
5from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \
6 FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification
def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None) -> bool:
    """
    Checks whether the installed xgboost version (as given by ``xgboost.__version__``) is at least
    the given version.

    The comparison proceeds component by component (major, then minor, then patch); components
    passed as None are skipped. Only the leading digits of an installed component are considered,
    such that suffixed versions like "2.1.0-dev" or "1.6.0rc1" are handled as "2.1.0" and "1.6.0"
    respectively instead of raising an error. Components which are absent from the installed
    version string (or which have no leading digits) are treated as 0.

    :param major: the minimum major version
    :param minor: the minimum minor version; if None, the minor version is not compared
    :param patch: the minimum patch version; if None, the patch version is not compared
    :return: True if the installed version is greater than or equal to the given version, False otherwise
    """
    import re  # function-scope import, such that the module's import section need not be modified

    components = xgboost.__version__.split(".")

    def installed_component(index: int) -> int:
        # the version string may have fewer components than are being compared (e.g. "2.0")
        if index >= len(components):
            return 0
        # consider leading digits only, such that suffixes like "rc1" or "-dev" do not break int conversion
        match = re.match(r"[0-9]+", components[index].strip())
        return int(match.group()) if match else 0

    for i, required in enumerate((major, minor, patch)):
        if required is None:
            continue
        installed = installed_component(i)
        # the first differing component decides the outcome
        if installed != required:
            return installed > required
    # all compared components are equal
    return True
class XGBGradientBoostedVectorRegressionModel(
        AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Regression model which is based on XGBoost's gradient boosted trees (``xgboost.XGBRegressor``)
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random state value, which is passed on to the base class constructor
            along with the model arguments
        :param model_args: further model arguments;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        model_constructor = xgboost.XGBRegressor
        super().__init__(model_constructor, random_state=random_state, **model_args)
class XGBRandomForestVectorRegressionModel(
        AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Regression model which is based on XGBoost's random forest implementation (``xgboost.XGBRFRegressor``)
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random state value, which is passed on to the base class constructor
            along with the model arguments
        :param model_args: further model arguments;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        model_constructor = xgboost.XGBRFRegressor
        super().__init__(model_constructor, random_state=random_state, **model_args)
class XGBGradientBoostedVectorClassificationModel(
        AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Classification model which is based on XGBoost's gradient boosted trees (``xgboost.XGBClassifier``)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random state value, which is passed on to the base class constructor
        :param use_balanced_class_weights: flag which is passed on to the base class constructor
            (presumably enables class weighting to counter class imbalance - see the base class to confirm)
        :param model_args: further model arguments;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # label encoding is enabled in the base class if and only if the installed xgboost version is at least 1.6
        super().__init__(
            xgboost.XGBClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=is_xgboost_version_at_least(1, 6),
            **model_args)
class XGBRandomForestVectorClassificationModel(
        AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Classification model which is based on XGBoost's random forest implementation (``xgboost.XGBRFClassifier``)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random state value, which is passed on to the base class constructor
        :param use_balanced_class_weights: flag which is passed on to the base class constructor
            (presumably enables class weighting to counter class imbalance - see the base class to confirm)
        :param model_args: further model arguments;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # label encoding is enabled in the base class if and only if the installed xgboost version is at least 1.6
        super().__init__(
            xgboost.XGBRFClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=is_xgboost_version_at_least(1, 6),
            **model_args)