Coverage for src/sensai/catboost.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from typing import Sequence, Union, Optional
2import logging
3import pandas as pd
4import re
5import catboost
7from .util.string import or_regex_group
8from .sklearn.sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnVectorClassificationModel
10log = logging.getLogger(__name__)
13# noinspection DuplicatedCode
class CatBoostVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel):
    """
    Regression model which uses ``catboost.CatBoostRegressor`` as the underlying model constructor and
    optionally declares categorical input columns via the ``cat_features`` model argument.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31,
            **model_args):
        """
        :param categorical_feature_names: either a sequence of feature names in the input data that are categorical
            or a single string, which is used as a regular expression that is applied to the column names
            (via ``re.match``, i.e. anchored at the start of the name only).
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically).
            In general, passing categorical features is preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed on to the model as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31)
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        # NOTE(review): the CatBoost documentation appears to restrict num_leaves/max_leaves to
        # grow_policy='Lossguide'; confirm that passing it unconditionally works with the default policy.
        super().__init__(catboost.CatBoostRegressor, random_seed=random_state, num_leaves=num_leaves, **model_args)

        regex: Optional[str]
        if isinstance(categorical_feature_names, str):
            # a plain string is taken to be a ready-made regular expression
            regex = categorical_feature_names
        elif categorical_feature_names is not None and len(categorical_feature_names) > 0:
            # NOTE(review): the resulting pattern is applied with re.match, which anchors at the start only;
            # if or_regex_group returns an unanchored alternation, a name such as 'age' would also select
            # a column named 'age_group' - confirm that this is intended.
            regex = or_regex_group(categorical_feature_names)
        else:
            # nothing specified (None or empty sequence): no explicit categorical features
            regex = None
        self._categorical_feature_name_regex: Optional[str] = regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Adds the ``cat_features`` entry to the model arguments, containing the positional indices of all input
        columns whose name matches the categorical feature regex. Does nothing if no regex is configured.

        :param inputs: the input data frame whose column names are matched against the regex
        :param outputs: the output data frame (unused here)
        """
        regex = self._categorical_feature_name_regex
        if regex is None:
            return
        pattern = re.compile(regex)
        # enumerate yields the true position of every matching column, which also remains correct
        # if column names are duplicated (list.index would always return the first occurrence)
        col_indices = [idx for idx, col in enumerate(inputs.columns) if pattern.match(col)]
        args = {"cat_features": col_indices}
        self.log.info(f"Updating model parameters with {args}")
        self.modelArgs.update(args)
48# noinspection DuplicatedCode
class CatBoostVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Classification model which uses ``catboost.CatBoostClassifier`` as the underlying model constructor and
    optionally declares categorical input columns via the ``cat_features`` model argument.
    """
    log = log.getChild(__qualname__)

    def __init__(self, categorical_feature_names: Optional[Union[Sequence[str], str]] = None, random_state=42, num_leaves=31,
            **model_args):
        """
        :param categorical_feature_names: either a sequence of feature names in the input data that are categorical
            or a single string, which is used as a regular expression that is applied to the column names
            (via ``re.match``, i.e. anchored at the start of the name only).
            Columns that have dtype 'category' (as will be the case for categorical columns created via FeatureGenerators)
            need not be specified (should be inferred automatically, but we have never actually tested this behaviour
            successfully for a classification model).
            In general, passing categorical features may be preferable to using one-hot encoding, for example.
        :param random_state: the random seed to use (passed on to the model as ``random_seed``)
        :param num_leaves: the maximum number of leaves in one tree (original catboost default is 31)
        :param model_args: see https://catboost.ai/docs/concepts/python-reference_parameters-list.html#python-reference_parameters-list
        """
        # NOTE(review): the CatBoost documentation appears to restrict num_leaves/max_leaves to
        # grow_policy='Lossguide'; confirm that passing it unconditionally works with the default policy.
        super().__init__(catboost.CatBoostClassifier, random_seed=random_state, num_leaves=num_leaves, **model_args)

        regex: Optional[str]
        if isinstance(categorical_feature_names, str):
            # a plain string is taken to be a ready-made regular expression
            regex = categorical_feature_names
        elif categorical_feature_names is not None and len(categorical_feature_names) > 0:
            # NOTE(review): the resulting pattern is applied with re.match, which anchors at the start only;
            # if or_regex_group returns an unanchored alternation, a name such as 'age' would also select
            # a column named 'age_group' - confirm that this is intended.
            regex = or_regex_group(categorical_feature_names)
        else:
            # nothing specified (None or empty sequence): no explicit categorical features
            regex = None
        self._categorical_feature_name_regex: Optional[str] = regex

    def _update_model_args(self, inputs: pd.DataFrame, outputs: pd.DataFrame):
        """
        Adds the ``cat_features`` entry to the model arguments, containing the positional indices of all input
        columns whose name matches the categorical feature regex. Does nothing if no regex is configured.

        :param inputs: the input data frame whose column names are matched against the regex
        :param outputs: the output data frame (unused here)
        """
        regex = self._categorical_feature_name_regex
        if regex is None:
            return
        pattern = re.compile(regex)
        # enumerate yields the true position of every matching column, which also remains correct
        # if column names are duplicated (list.index would always return the first occurrence)
        col_indices = [idx for idx, col in enumerate(inputs.columns) if pattern.match(col)]
        args = {"cat_features": col_indices}
        self.log.info(f"Updating model parameters with {args}")
        self.modelArgs.update(args)