Coverage for src/sensai/catboost.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from typing import Sequence, Union, Optional 

2import logging 

3import pandas as pd 

4import re 

5import catboost 

6 

7from .util.string import or_regex_group 

8from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel 

9 

# Module-level logger; the model classes below derive their own child loggers from it.
log = logging.getLogger(name=__name__)

11 

12 

# noinspection DuplicatedCode
class CatBoostVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel):
    """
    Vector regression model backed by ``catboost.CatBoostRegressor``, which additionally allows categorical
    features to be declared by name (or via a regular expression) and forwards them to catboost as ``cat_features``.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31,
            **model_args):
        """
        :param categorical_feature_names: either a sequence of feature names in the input data that are categorical
            (matched exactly against the column names) or a single string, which is interpreted as a regular expression
            that is matched (via ``re.match``, i.e. anchored at the start of the column name only) against the column names.
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically).
            In general, passing categorical features is preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed to catboost as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31)
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        # NOTE(review): catboost documents num_leaves (alias max_leaves) as applicable only with grow_policy="Lossguide";
        # confirm that always passing it with the default grow policy is accepted by the installed catboost version.
        super().__init__(catboost.CatBoostRegressor, random_seed=random_state, num_leaves=num_leaves, **model_args)

        if isinstance(categorical_feature_names, str):
            # a single string is a user-supplied regular expression and is used as-is
            categorical_feature_name_regex = categorical_feature_names
        elif categorical_feature_names is not None and len(categorical_feature_names) > 0:
            # literal names: wrap the alternation and anchor it at the end (\Z) so that, combined with re.match
            # (which anchors at the start), each name matches only the column of exactly that name rather than
            # every column that merely starts with it (e.g. "age" must not also select "age_bucket")
            categorical_feature_name_regex = rf"(?:{or_regex_group(categorical_feature_names)})\Z"
        else:
            categorical_feature_name_regex = None
        self._categorical_feature_name_regex: Optional[str] = categorical_feature_name_regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Sets catboost's ``cat_features`` argument to the positional indices of the input columns whose names match
        the configured categorical feature regex; does nothing if no categorical features were configured.

        :param inputs: the input data the model is about to be fitted on
        :param outputs: the output data (unused here)
        """
        if self._categorical_feature_name_regex is not None:
            pattern = re.compile(self._categorical_feature_name_regex)
            # a single enumerate pass yields each matching column's own position; this avoids the quadratic
            # list.index lookups and stays correct for duplicated column names (list.index returns the first one).
            # str() guards against non-string column labels, which re cannot match directly.
            col_indices = [i for i, col in enumerate(inputs.columns) if pattern.match(str(col))]
            args = {"cat_features": col_indices}
            self.log.info(f"Updating model parameters with {args}")
            # NOTE(review): the rest of this module uses snake_case; confirm that the base class really stores its
            # constructor arguments as `modelArgs` (and not e.g. `model_args`).
            self.modelArgs.update(args)

46 

47 

# noinspection DuplicatedCode
class CatBoostVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model backed by ``catboost.CatBoostClassifier``, which additionally allows categorical
    features to be declared by name (or via a regular expression) and forwards them to catboost as ``cat_features``.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31,
            **model_args):
        """
        :param categorical_feature_names: either a sequence of feature names in the input data that are categorical
            (matched exactly against the column names) or a single string, which is interpreted as a regular expression
            that is matched (via ``re.match``, i.e. anchored at the start of the column name only) against the column names.
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically, but we have never actually tested this behaviour
            successfully for a classification model).
            In general, passing categorical features may be preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed to catboost as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31)
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        # NOTE(review): catboost documents num_leaves (alias max_leaves) as applicable only with grow_policy="Lossguide";
        # confirm that always passing it with the default grow policy is accepted by the installed catboost version.
        super().__init__(catboost.CatBoostClassifier, random_seed=random_state, num_leaves=num_leaves, **model_args)

        if isinstance(categorical_feature_names, str):
            # a single string is a user-supplied regular expression and is used as-is
            categorical_feature_name_regex = categorical_feature_names
        elif categorical_feature_names is not None and len(categorical_feature_names) > 0:
            # literal names: wrap the alternation and anchor it at the end (\Z) so that, combined with re.match
            # (which anchors at the start), each name matches only the column of exactly that name rather than
            # every column that merely starts with it (e.g. "age" must not also select "age_bucket")
            categorical_feature_name_regex = rf"(?:{or_regex_group(categorical_feature_names)})\Z"
        else:
            categorical_feature_name_regex = None
        self._categorical_feature_name_regex: Optional[str] = categorical_feature_name_regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Sets catboost's ``cat_features`` argument to the positional indices of the input columns whose names match
        the configured categorical feature regex; does nothing if no categorical features were configured.

        :param inputs: the input data the model is about to be fitted on
        :param outputs: the output data (unused here)
        """
        if self._categorical_feature_name_regex is not None:
            pattern = re.compile(self._categorical_feature_name_regex)
            # a single enumerate pass yields each matching column's own position; this avoids the quadratic
            # list.index lookups and stays correct for duplicated column names (list.index returns the first one).
            # str() guards against non-string column labels, which re cannot match directly.
            col_indices = [i for i, col in enumerate(inputs.columns) if pattern.match(str(col))]
            args = {"cat_features": col_indices}
            self.log.info(f"Updating model parameters with {args}")
            # NOTE(review): the rest of this module uses snake_case; confirm that the base class really stores its
            # constructor arguments as `modelArgs` (and not e.g. `model_args`).
            self.modelArgs.update(args)