Coverage for src/sensai/catboost.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1from typing import Sequence, Union, Optional 

2import logging 

3import pandas as pd 

4import re 

5import catboost 

6 

7from .util.string import or_regex_group 

8from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel 

9 

# Module-level logger; the model classes in this module derive their class-specific loggers from it via getChild.
log = logging.getLogger(__name__)

11 

12 

13# noinspection DuplicatedCode 

class CatBoostVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel):
    """
    Vector regression model based on ``catboost.CatBoostRegressor``.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31, **model_args):
        """
        :param categorical_feature_names: either a sequence of names of features in the input data that are categorical,
            or a single string, which is interpreted as a regular expression that is matched (via ``re.match``, i.e. anchored
            at the start of the name) against the input column names.
            ``None``, an empty string or an empty sequence means that no categorical features are specified explicitly.
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically).
            In general, passing categorical features is preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed to catboost as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31).
            NOTE(review): CatBoost documents this parameter (alias ``max_leaves``) as applicable to the ``Lossguide``
            grow policy only; confirm that it is accepted together with the grow policy in use.
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        super().__init__(catboost.CatBoostRegressor, random_seed=random_state, num_leaves=num_leaves, **model_args)

        # Normalise the user's specification to a single regex, or None if nothing was specified.
        # len() is used instead of truthiness, which is ambiguous for array-like containers such as numpy arrays.
        # An empty string is treated like an empty sequence: as a regex, "" would match (and thus mark as categorical)
        # every single column.
        if categorical_feature_names is None or len(categorical_feature_names) == 0:
            categorical_feature_name_regex = None
        elif isinstance(categorical_feature_names, str):
            categorical_feature_name_regex = categorical_feature_names
        else:
            categorical_feature_name_regex = or_regex_group(categorical_feature_names)
        self._categorical_feature_name_regex: Optional[str] = categorical_feature_name_regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Sets catboost's ``cat_features`` parameter to the positional indices of all input columns whose names match
        the categorical feature name regex; does nothing if no categorical features were specified.

        :param inputs: the input data frame whose column names are matched
        :param outputs: the output data frame (unused)
        """
        if self._categorical_feature_name_regex is None:
            return
        pattern = re.compile(self._categorical_feature_name_regex)
        # enumerate yields the true position of every matching column; list.index would return the position of the
        # first occurrence only (wrong/duplicate indices for duplicate column names) at quadratic cost
        col_indices = [idx for idx, col in enumerate(inputs.columns) if pattern.match(col)]
        args = {"cat_features": col_indices}
        self.log.info(f"Updating model parameters with {args}")
        # NOTE(review): assumes the base class keeps the estimator constructor kwargs in ``modelArgs``; confirm in sklearn_base
        self.modelArgs.update(args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: True, i.e. this model supports sample weights
        """
        return True

49 

50 

51# noinspection DuplicatedCode 

class CatBoostVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model based on ``catboost.CatBoostClassifier``.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31, **model_args):
        """
        :param categorical_feature_names: either a sequence of names of features in the input data that are categorical,
            or a single string, which is interpreted as a regular expression that is matched (via ``re.match``, i.e. anchored
            at the start of the name) against the input column names.
            ``None``, an empty string or an empty sequence means that no categorical features are specified explicitly.
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically, but we have never actually tested this behaviour
            successfully for a classification model).
            In general, passing categorical features may be preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed to catboost as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31).
            NOTE(review): CatBoost documents this parameter (alias ``max_leaves``) as applicable to the ``Lossguide``
            grow policy only; confirm that it is accepted together with the grow policy in use.
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        super().__init__(catboost.CatBoostClassifier, random_seed=random_state, num_leaves=num_leaves, **model_args)

        # Normalise the user's specification to a single regex, or None if nothing was specified.
        # len() is used instead of truthiness, which is ambiguous for array-like containers such as numpy arrays.
        # An empty string is treated like an empty sequence: as a regex, "" would match (and thus mark as categorical)
        # every single column.
        if categorical_feature_names is None or len(categorical_feature_names) == 0:
            categorical_feature_name_regex = None
        elif isinstance(categorical_feature_names, str):
            categorical_feature_name_regex = categorical_feature_names
        else:
            categorical_feature_name_regex = or_regex_group(categorical_feature_names)
        self._categorical_feature_name_regex: Optional[str] = categorical_feature_name_regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Sets catboost's ``cat_features`` parameter to the positional indices of all input columns whose names match
        the categorical feature name regex; does nothing if no categorical features were specified.

        :param inputs: the input data frame whose column names are matched
        :param outputs: the output data frame (unused)
        """
        if self._categorical_feature_name_regex is None:
            return
        pattern = re.compile(self._categorical_feature_name_regex)
        # enumerate yields the true position of every matching column; list.index would return the position of the
        # first occurrence only (wrong/duplicate indices for duplicate column names) at quadratic cost
        col_indices = [idx for idx, col in enumerate(inputs.columns) if pattern.match(col)]
        args = {"cat_features": col_indices}
        self.log.info(f"Updating model parameters with {args}")
        # NOTE(review): assumes the base class keeps the estimator constructor kwargs in ``modelArgs``; confirm in sklearn_base
        self.modelArgs.update(args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: True, i.e. this model supports sample weights
        """
        return True