Coverage for src/sensai/xgboost.py: 0%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from typing import Optional 

2 

3import xgboost 

4 

5from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel, \ 

6 FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnClassification 

7 

8 

def is_xgboost_version_at_least(major: int, minor: Optional[int] = None, patch: Optional[int] = None, *,
        version: Optional[str] = None) -> bool:
    """
    Checks whether the installed xgboost version (or an explicitly given version string) is at least
    the given version.

    Components are compared from most to least significant. Components passed as None are not checked.
    Only the leading digits of each dot-separated component of the version string are considered, so
    pre-release/local suffixes such as ``"2.0.0rc1"`` or ``"1.7.6+cu118"`` are accepted. Components
    missing from the version string (e.g. the patch level in ``"2.0"``) are treated as 0.

    :param major: the required major version
    :param minor: the required minor version; None to not check it
    :param patch: the required patch version; None to not check it
    :param version: the version string to check; if None, ``xgboost.__version__`` is used
    :return: True if the version is greater than or equal to the required version, False otherwise
    """
    import re

    if version is None:
        version = xgboost.__version__

    installed_components = []
    for component in version.split("."):
        # take only the leading digits; suffixes like "rc1", "-dev" or "+cu118" would break int()
        digits = re.match(r"\d+", component)
        installed_components.append(int(digits.group()) if digits else 0)

    for i, required in enumerate((major, minor, patch)):
        if required is None:
            continue
        # a component absent from the version string counts as 0 (e.g. "2.0" is treated as "2.0.0")
        installed = installed_components[i] if i < len(installed_components) else 0
        if installed > required:
            return True
        if installed < required:
            return False
    return True

19 

20 

class XGBGradientBoostedVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Vector regression model based on xgboost's gradient boosted trees regressor (``xgboost.XGBRegressor``).
    """
    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed, passed on as ``random_state`` together with the model arguments
        :param model_args: further keyword arguments for ``xgboost.XGBRegressor``;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor
        """
        regressor_class = xgboost.XGBRegressor
        # random_state goes first, followed by the user-supplied arguments in their given order
        forwarded_args = {"random_state": random_state, **model_args}
        super().__init__(regressor_class, **forwarded_args)

31 

32 

class XGBRandomForestVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel,
        FeatureImportanceProviderSkLearnRegressionMultipleOneDim):
    """
    Vector regression model based on xgboost's random forest regressor (``xgboost.XGBRFRegressor``).
    """
    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: the random seed, passed on as ``random_state`` together with the model arguments
        :param model_args: further keyword arguments for ``xgboost.XGBRFRegressor``;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFRegressor
        """
        regressor_class = xgboost.XGBRFRegressor
        # random_state goes first, followed by the user-supplied arguments in their given order
        forwarded_args = {"random_state": random_state, **model_args}
        super().__init__(regressor_class, **forwarded_args)

43 

44 

class XGBGradientBoostedVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model based on xgboost's gradient boosted trees classifier (``xgboost.XGBClassifier``).
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, passed on as ``random_state``
        :param use_balanced_class_weights: passed on unchanged to the base class
        :param model_args: further keyword arguments for ``xgboost.XGBClassifier``;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier
        """
        # Label encoding is requested from the base class for xgboost >= 1.6 only.
        # NOTE(review): presumably because newer xgboost versions no longer encode class labels
        # internally and expect labels 0..n-1 — confirm against the xgboost release notes.
        fixed_args = {
            "random_state": random_state,
            "use_balanced_class_weights": use_balanced_class_weights,
            "use_label_encoding": is_xgboost_version_at_least(1, 6),
        }
        # unpacked separately so that a colliding key in model_args still raises TypeError
        super().__init__(xgboost.XGBClassifier, **fixed_args, **model_args)

56 

57 

class XGBRandomForestVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model based on xgboost's random forest classifier (``xgboost.XGBRFClassifier``).
    """
    def __init__(self, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param random_state: the random seed, passed on as ``random_state``
        :param use_balanced_class_weights: passed on unchanged to the base class
        :param model_args: further keyword arguments for ``xgboost.XGBRFClassifier``;
            see https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRFClassifier
        """
        # Label encoding is requested from the base class for xgboost >= 1.6 only.
        # NOTE(review): presumably because newer xgboost versions no longer encode class labels
        # internally and expect labels 0..n-1 — confirm against the xgboost release notes.
        fixed_args = {
            "random_state": random_state,
            "use_balanced_class_weights": use_balanced_class_weights,
            "use_label_encoding": is_xgboost_version_at_least(1, 6),
        }
        # unpacked separately so that a colliding key in model_args still raises TypeError
        super().__init__(xgboost.XGBRFClassifier, **fixed_args, **model_args)