# Source code for sensai.xgboost

import logging
from typing import Optional

import pandas as pd
import xgboost

from . import InputOutputData
from .data import DataSplitter
from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \
    FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification, ActualFitParams
from .util.pickle import setstate

log = logging.getLogger(__name__)


def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None) -> bool:
    """
    Checks whether the installed version of xgboost is at least the given version.

    :param major: the required major version
    :param minor: the required minor version; if None, only the major version is compared
    :param patch: the required patch version; if None, the patch level is not compared
    :return: True if the installed version is greater than or equal to the given version, False otherwise
    """
    def parse_component(component: str) -> int:
        # Tolerate pre-release/dev suffixes such as "0rc1" or "0-dev" by reading only
        # the leading digits; a component with no leading digits counts as 0.
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        return int(digits) if digits else 0

    components = xgboost.__version__.split(".")
    for i, required_version in enumerate((major, minor, patch)):
        if required_version is not None:
            # a missing component (e.g. installed version "2.0" when patch is requested) counts as 0
            installed_version = parse_component(components[i]) if i < len(components) else 0
            if installed_version > required_version:
                return True
            if installed_version < required_version:
                return False
    # all requested components are equal
    return True
class XGBGradientBoostedVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    XGBoost's regression model using gradient boosted trees
    """
    def __init__(self, random_state=42, early_stopping_rounds: Optional[int] = None,
            early_stopping_data_splitter: Optional[DataSplitter] = None, **model_args):
        """
        :param random_state: the random seed to use
        :param early_stopping_rounds: the number of rounds without improvement on the validation set after which
            training shall be stopped; if not None, ``early_stopping_data_splitter`` must also be provided
        :param early_stopping_data_splitter: the data splitter with which to obtain the validation set used
            for early stopping
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        super().__init__(xgboost.XGBRegressor, random_state=random_state, early_stopping_rounds=early_stopping_rounds,
            **model_args)
        self.is_early_stopping_enabled = early_stopping_rounds is not None
        # fail fast: without a splitter, early stopping would otherwise fail at fit time with an
        # opaque AttributeError (None.split) in _compute_actual_fit_params
        if self.is_early_stopping_enabled and early_stopping_data_splitter is None:
            raise ValueError("early_stopping_rounds requires early_stopping_data_splitter to be provided")
        self.early_stopping_data_splitter = early_stopping_data_splitter

    def __setstate__(self, state):
        # older pickled instances pre-date early stopping support; restore them with safe defaults
        setstate(XGBGradientBoostedVectorRegressionModel, self, state, new_default_properties=dict(
            is_early_stopping_enabled=False,
            early_stopping_data_splitter=None))

    def is_sample_weight_supported(self) -> bool:
        return True

    def _compute_actual_fit_params(self, inputs: pd.DataFrame, outputs: pd.DataFrame,
            weights: Optional[pd.Series] = None) -> ActualFitParams:
        """
        If early stopping is enabled, splits off a validation set, passes it to the underlying
        XGBoost model via ``eval_set`` and fits on the remaining training portion only.

        :param inputs: the training inputs
        :param outputs: the training outputs
        :param weights: optional sample weights
        :return: the fit parameters to actually use
        """
        kwargs = {}
        if self.is_early_stopping_enabled:
            data = InputOutputData(inputs, outputs, weights=weights)
            train_data, val_data = self.early_stopping_data_splitter.split(data)
            train_data: InputOutputData
            kwargs["eval_set"] = [(val_data.inputs, val_data.outputs)]
            inputs = train_data.inputs
            outputs = train_data.outputs
            weights = train_data.weights
            log.info(f"Early stopping enabled with validation set of size {len(val_data)}")
        params = super()._compute_actual_fit_params(inputs, outputs, weights=weights)
        params.kwargs.update(kwargs)
        return params
class XGBRandomForestVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Random forest regression model from the XGBoost library.
    """
    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed to use
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        super().__init__(xgboost.XGBRFRegressor, random_state=random_state, **model_args)

    def is_sample_weight_supported(self) -> bool:
        # sample weights are passed through to the underlying XGBoost model
        return True
class XGBGradientBoostedVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Gradient boosted trees classification model from the XGBoost library.
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed to use
        :param use_balanced_class_weights: whether to compute class weights balancing the class frequencies
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # from xgboost 1.6 onwards, label encoding is handled on our side
        # (presumably because xgboost removed its internal label encoder — TODO confirm)
        apply_label_encoding = is_xgboost_version_at_least(1, 6)
        super().__init__(xgboost.XGBClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=apply_label_encoding,
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        # sample weights are passed through to the underlying XGBoost model
        return True
class XGBRandomForestVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Random forest classification model from the XGBoost library.
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed to use
        :param use_balanced_class_weights: whether to compute class weights balancing the class frequencies
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # from xgboost 1.6 onwards, label encoding is handled on our side
        # (presumably because xgboost removed its internal label encoder — TODO confirm)
        apply_label_encoding = is_xgboost_version_at_least(1, 6)
        super().__init__(xgboost.XGBRFClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=apply_label_encoding,
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        # sample weights are passed through to the underlying XGBoost model
        return True