Source code for sensai.xgboost
from typing import Optional
import xgboost
from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \
FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification
def _parse_version(version: str) -> tuple:
    """
    Parses the leading numeric release components of a version string.

    Splitting is done on ".". A component consisting only of digits is taken as is. For a component with a
    non-numeric suffix, only its numeric prefix is used, and parsing stops there. This covers e.g. the "0rc1" in
    "2.0.0rc1" or the "0-dev" in "2.1.0-dev". Parsing also stops at the first component without a leading digit,
    e.g. the "dev0" in "2.0.0.dev0". Pre-release/development markers are therefore ignored.

    :param version: a version string such as "1.7.6", "2.0.0rc1" or "2.1.0-dev"
    :return: the tuple of parsed integer components, e.g. (2, 0, 0) for "2.0.0rc1"; it may have fewer than three
        elements (or none at all) if the string has fewer numeric components
    """
    components = []
    for part in version.split("."):
        # strip leading ASCII digits; whatever was removed is the numeric prefix of this component
        remainder = part.lstrip("0123456789")
        numeric_prefix = part[:len(part) - len(remainder)]
        if not numeric_prefix:
            break
        components.append(int(numeric_prefix))
        if remainder:
            # non-numeric suffix (pre-release/dev marker): ignore everything that follows
            break
    return tuple(components)


def _is_version_at_least(installed: tuple, major: int, minor: Optional[int] = None,
        patch: Optional[int] = None) -> bool:
    """
    Checks whether an installed version (as returned by _parse_version) is at least the given version.

    :param installed: the installed version's integer components; missing trailing components are treated as 0
    :param major: the required major version
    :param minor: the required minor version; if None, the minor component is not compared
    :param patch: the required patch version; if None, the patch component is not compared
    :return: True if the installed version is greater than or equal to the required version on all compared
        components (compared lexicographically, major first)
    """
    for index, required in enumerate((major, minor, patch)):
        if required is None:
            continue
        actual = installed[index] if index < len(installed) else 0
        if actual != required:
            return actual > required
    return True


def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None) -> bool:
    """
    Checks whether the installed xgboost version is at least the given version.

    Pre-release/development suffixes in the installed version string (e.g. "2.0.0rc1", "2.1.0-dev") are ignored.
    Such a version is therefore considered equal to the corresponding release.

    :param major: the required major version
    :param minor: the required minor version; if None, the minor component is not compared
    :param patch: the required patch version; if None, the patch component is not compared
    :return: True if xgboost.__version__ is greater than or equal to the given version
    """
    return _is_version_at_least(_parse_version(xgboost.__version__), major, minor, patch)
class XGBGradientBoostedVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Vector regression model based on XGBoost's gradient boosted trees regressor (xgboost.XGBRegressor)
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed for the XGBoost model
        :param model_args: further keyword arguments for the XGBoost model;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        # random_state goes first, followed by the caller's arguments (same order as a direct keyword call)
        regressor_args = dict(random_state=random_state, **model_args)
        super().__init__(xgboost.XGBRegressor, **regressor_args)
class XGBRandomForestVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Vector regression model based on XGBoost's random forest regressor (xgboost.XGBRFRegressor)
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed for the XGBoost model
        :param model_args: further keyword arguments for the XGBoost model;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        # random_state goes first, followed by the caller's arguments (same order as a direct keyword call)
        regressor_args = dict(random_state=random_state, **model_args)
        super().__init__(xgboost.XGBRFRegressor, **regressor_args)
class XGBGradientBoostedVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model based on XGBoost's gradient boosted trees classifier (xgboost.XGBClassifier)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed for the XGBoost model
        :param use_balanced_class_weights: passed on unchanged to the sensai base class
            (see AbstractSkLearnVectorClassificationModel for its semantics)
        :param model_args: further keyword arguments for the XGBoost model;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # Label encoding is requested only for xgboost >= 1.6. Presumably, these versions expect class labels
        # to already be encoded as 0..n-1 (NOTE(review): confirm against the xgboost 1.6 release notes).
        label_encoding_required = is_xgboost_version_at_least(1, 6)
        super().__init__(
            xgboost.XGBClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=label_encoding_required,
            **model_args
        )
class XGBRandomForestVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model based on XGBoost's random forest classifier (xgboost.XGBRFClassifier)
    """

    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed for the XGBoost model
        :param use_balanced_class_weights: passed on unchanged to the sensai base class
            (see AbstractSkLearnVectorClassificationModel for its semantics)
        :param model_args: further keyword arguments for the XGBoost model;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # Label encoding is requested only for xgboost >= 1.6. Presumably, these versions expect class labels
        # to already be encoded as 0..n-1 (NOTE(review): confirm against the xgboost 1.6 release notes).
        label_encoding_required = is_xgboost_version_at_least(1, 6)
        super().__init__(
            xgboost.XGBRFClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=label_encoding_required,
            **model_args
        )