Coverage for src/sensai/xgboost.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import logging 

2from typing import Optional 

3 

4import pandas as pd 

5import xgboost 

6 

7from . import InputOutputData 

8from .data import DataSplitter 

9from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \ 

10 FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification, ActualFitParams 

11from .util.pickle import setstate 

12 

13log = logging.getLogger(__name__) 

14 

15 

def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None) -> bool:
    """
    Checks whether the installed XGBoost version (``xgboost.__version__``) is at least the given version.

    The components are compared in the order major, minor, patch, and the first component that differs
    decides the result. Components that are passed as None are not compared.

    Only the leading digits of each installed version component are considered, such that pre-release
    or development suffixes (e.g. "2.1.0rc1" or "2.0.0-dev") do not cause a parsing error; such versions
    are treated like the corresponding release. Components that are missing from the installed version
    string (e.g. the patch version in "1.6") are treated as 0.

    :param major: the minimum major version
    :param minor: the minimum minor version (not compared if None)
    :param patch: the minimum patch version (not compared if None)
    :return: True if the installed version is at least the given version, False otherwise
    """
    def leading_int(component: str) -> int:
        # e.g. "12" -> 12, "0rc1" -> 0, "0-dev" -> 0; a component without leading digits counts as 0
        num_digits = len(component) - len(component.lstrip("0123456789"))
        return int(component[:num_digits]) if num_digits > 0 else 0

    components = xgboost.__version__.split(".")
    for i, required in enumerate((major, minor, patch)):
        if required is None:
            continue
        installed = leading_int(components[i]) if i < len(components) else 0
        if installed != required:
            # the first differing component decides
            return installed > required
    # all compared components are equal
    return True

26 

27 

class XGBGradientBoostedVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    XGBoost's regression model using gradient boosted trees.

    Optionally supports early stopping, where the validation set is split off from the data
    that is passed for fitting.
    """

    def __init__(self, random_state=42,
            early_stopping_rounds: Optional[int] = None,
            early_stopping_data_splitter: Optional[DataSplitter] = None,
            **model_args):
        """
        :param random_state: the random seed, which is passed on to the base class constructor along with
            the model arguments
        :param early_stopping_rounds: the number of early stopping rounds (passed on as a model argument);
            early stopping is enabled iff this is not None
        :param early_stopping_data_splitter: the splitter which divides the data passed for fitting into
            the actual training set and the validation set; required when early stopping is enabled
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        super().__init__(xgboost.XGBRegressor, random_state=random_state, early_stopping_rounds=early_stopping_rounds,
            **model_args)
        self.is_early_stopping_enabled = early_stopping_rounds is not None
        self.early_stopping_data_splitter = early_stopping_data_splitter

    def __setstate__(self, state):
        # instances persisted before early stopping was introduced lack these attributes; use defaults
        setstate(XGBGradientBoostedVectorRegressionModel, self, state,
            new_default_properties=dict(
                is_early_stopping_enabled=False,
                early_stopping_data_splitter=None))

    def is_sample_weight_supported(self) -> bool:
        return True

    def _compute_actual_fit_params(self, inputs: pd.DataFrame, outputs: pd.DataFrame, weights: Optional[pd.Series] = None) -> ActualFitParams:
        """
        Computes the parameters for fitting. If early stopping is enabled, the given data is split into
        a training set and a validation set; the validation set is added as ``eval_set`` to the keyword
        arguments and only the training set is passed on to the base class implementation.

        :param inputs: the input data
        :param outputs: the output data
        :param weights: the sample weights (optional)
        :return: the fit parameters
        :raises ValueError: if early stopping is enabled but no data splitter was provided
        """
        kwargs = {}
        if self.is_early_stopping_enabled:
            splitter = self.early_stopping_data_splitter
            if splitter is None:
                # fail with a clear message rather than an AttributeError on None
                raise ValueError("Early stopping is enabled (early_stopping_rounds was specified), "
                    "but no early_stopping_data_splitter was provided")
            data = InputOutputData(inputs, outputs, weights=weights)
            train_data, val_data = splitter.split(data)
            train_data: InputOutputData
            kwargs["eval_set"] = [(val_data.inputs, val_data.outputs)]
            # fit on the training portion only; the validation portion is reserved for early stopping
            inputs = train_data.inputs
            outputs = train_data.outputs
            weights = train_data.weights
            log.info(f"Early stopping enabled with validation set of size {len(val_data)}")
        params = super()._compute_actual_fit_params(inputs, outputs, weights=weights)
        params.kwargs.update(kwargs)
        return params

69 

70 

class XGBRandomForestVectorRegressionModel(
        AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Regression model based on XGBoost's random forest implementation (``xgboost.XGBRFRegressor``)
    """
    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed, which is passed on to the base class constructor along with
            the model arguments
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        super().__init__(
            xgboost.XGBRFRegressor,
            random_state=random_state,
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: True, i.e. this model declares support for sample weights
        """
        return True

84 

85 

class XGBGradientBoostedVectorClassificationModel(
        AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Classification model based on XGBoost's gradient boosted trees (``xgboost.XGBClassifier``)
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, which is passed on to the base class constructor
        :param use_balanced_class_weights: passed on to the base class constructor
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # label encoding is requested from the base class for XGBoost 1.6 and later
        # NOTE(review): presumably because these versions require encoded class labels - confirm
        label_encoding_required = is_xgboost_version_at_least(1, 6)
        super().__init__(
            xgboost.XGBClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=label_encoding_required,
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: True, i.e. this model declares support for sample weights
        """
        return True

100 

101 

class XGBRandomForestVectorClassificationModel(
        AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Classification model based on XGBoost's random forest implementation (``xgboost.XGBRFClassifier``)
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, which is passed on to the base class constructor
        :param use_balanced_class_weights: passed on to the base class constructor
        :param model_args: See https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # label encoding is requested from the base class for XGBoost 1.6 and later
        # NOTE(review): presumably because these versions require encoded class labels - confirm
        label_encoding_required = is_xgboost_version_at_least(1, 6)
        super().__init__(
            xgboost.XGBRFClassifier,
            random_state=random_state,
            use_balanced_class_weights=use_balanced_class_weights,
            use_label_encoding=label_encoding_required,
            **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: True, i.e. this model declares support for sample weights
        """
        return True