Coverage for src/sensai/sklearn/sklearn_classification.py: 71%
49 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import logging
2from typing import Union, Optional
4import numpy as np
5import sklearn.ensemble
6import sklearn.naive_bayes
7import sklearn.neural_network
8import sklearn.tree
9from sklearn.ensemble import RandomForestClassifier
10from sklearn.tree import DecisionTreeClassifier
12from .sklearn_base import AbstractSkLearnVectorClassificationModel, FeatureImportanceProviderSkLearnClassification
14log = logging.getLogger(__name__)
class SkLearnDecisionTreeVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``DecisionTreeClassifier`` to the
    sklearn-based abstract classification model.
    """

    def __init__(self, min_samples_leaf=1, random_state=42, **model_args):
        """
        :param min_samples_leaf: forwarded to the base class as the keyword argument of the same name
        :param random_state: forwarded to the base class as the keyword argument of the same name
        :param model_args: any further keyword arguments, forwarded unchanged
        """
        # NOTE(review): the base class is not visible here; presumably it uses these keyword
        # arguments to instantiate DecisionTreeClassifier - confirm against sklearn_base.
        tree_args = {"min_samples_leaf": min_samples_leaf, "random_state": random_state, **model_args}
        super().__init__(DecisionTreeClassifier, **tree_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always True for this model
        """
        return True
class SkLearnRandomForestVectorClassificationModel(AbstractSkLearnVectorClassificationModel,
        FeatureImportanceProviderSkLearnClassification):
    """
    Vector classification model which hands sklearn's ``RandomForestClassifier`` to the
    sklearn-based abstract classification model; additionally mixes in the sklearn
    feature importance provider for classification.
    """

    def __init__(self, n_estimators=100, min_samples_leaf=1, random_state=42, use_balanced_class_weights=False, **model_args):
        """
        :param n_estimators: forwarded to the base class as the keyword argument of the same name
        :param min_samples_leaf: forwarded to the base class as the keyword argument of the same name
        :param random_state: forwarded to the base class as the keyword argument of the same name
        :param use_balanced_class_weights: forwarded to the base class as the keyword argument of the same name
            (NOTE(review): presumably consumed by the base class rather than by sklearn - confirm)
        :param model_args: any further keyword arguments, forwarded unchanged
        """
        forest_args = {
            "random_state": random_state,
            "min_samples_leaf": min_samples_leaf,
            "n_estimators": n_estimators,
            "use_balanced_class_weights": use_balanced_class_weights,
            **model_args,
        }
        super().__init__(RandomForestClassifier, **forest_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always True for this model
        """
        return True
class SkLearnMLPVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``MLPClassifier`` to the
    sklearn-based abstract classification model.
    """

    def __init__(self, hidden_layer_sizes=(100,), activation: str = "relu",
            solver: str = "adam", batch_size: Union[int, str] = "auto", random_state: Optional[int] = 42,
            max_iter: int = 200, early_stopping: bool = False, n_iter_no_change: int = 10, **model_args):
        """
        :param hidden_layer_sizes: the sequence of hidden layer sizes
        :param activation: {"identity", "logistic", "tanh", "relu"} the activation function to use for hidden layers
        :param solver: {"adam", "lbfgs", "sgd"} the name of the solver to apply
        :param batch_size: the batch size or "auto" for min(200, data set size)
        :param random_state: the random seed for reproducibility; use None if it shall not be specifically defined
        :param max_iter: the number of iterations (gradient steps for L-BFGS, epochs for other solvers)
        :param early_stopping: whether to use early stopping (stop training after n_iter_no_change epochs without improvement)
        :param n_iter_no_change: the number of iterations after which to stop early (if early_stopping is enabled)
        :param model_args: additional arguments to pass on to MLPClassifier, see
            https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html
        """
        # Collect the explicitly named settings first, then append the pass-through arguments
        mlp_args = {
            "hidden_layer_sizes": hidden_layer_sizes,
            "activation": activation,
            "random_state": random_state,
            "solver": solver,
            "batch_size": batch_size,
            "max_iter": max_iter,
            "early_stopping": early_stopping,
            "n_iter_no_change": n_iter_no_change,
            **model_args,
        }
        super().__init__(sklearn.neural_network.MLPClassifier, **mlp_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always False for this model
        """
        return False
class SkLearnMultinomialNBVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``MultinomialNB`` (multinomial
    naive Bayes) to the sklearn-based abstract classification model.
    """

    def __init__(self, **model_args):
        """
        :param model_args: keyword arguments which are forwarded unchanged to the base class
            together with the MultinomialNB class
        """
        model_class = sklearn.naive_bayes.MultinomialNB
        super().__init__(model_class, **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always True for this model
        """
        return True
class SkLearnSVCVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``SVC`` (support vector classifier)
    to the sklearn-based abstract classification model.
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: forwarded to the base class as the keyword argument of the same name
        :param model_args: any further keyword arguments, forwarded unchanged
        """
        # Import the submodule explicitly: this file's top-level imports do not include
        # sklearn.svm, so the attribute access `sklearn.svm` would otherwise only work if
        # some other import happened to load the submodule as a side effect.
        from sklearn.svm import SVC

        super().__init__(SVC, random_state=random_state, **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always True for this model
        """
        return True
class SkLearnLogisticRegressionVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``LogisticRegression`` to the
    sklearn-based abstract classification model.
    """

    def __init__(self, random_state=42, **model_args):
        """
        :param random_state: forwarded to the base class as the keyword argument of the same name
        :param model_args: any further keyword arguments, forwarded unchanged
        """
        # Import the submodule explicitly: this file's top-level imports do not include
        # sklearn.linear_model, so the attribute access `sklearn.linear_model` would otherwise
        # only work if some other import happened to load the submodule as a side effect.
        from sklearn.linear_model import LogisticRegression

        super().__init__(LogisticRegression, random_state=random_state, **model_args)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always True for this model
        """
        return True
class SkLearnKNeighborsVectorClassificationModel(AbstractSkLearnVectorClassificationModel):
    """
    Vector classification model which hands sklearn's ``KNeighborsClassifier`` to the
    sklearn-based abstract classification model.
    """

    def __init__(self, **model_args):
        """
        :param model_args: keyword arguments which are forwarded unchanged to the base class
            together with the KNeighborsClassifier class
        """
        # Import the submodule explicitly: this file's top-level imports do not include
        # sklearn.neighbors, so the attribute access `sklearn.neighbors` would otherwise
        # only work if some other import happened to load the submodule as a side effect.
        from sklearn.neighbors import KNeighborsClassifier

        super().__init__(KNeighborsClassifier, **model_args)

    def _predict_sklearn(self, input_values):
        """
        Converts the inputs to a contiguous array before delegating to the base class implementation.

        :param input_values: the inputs for which to predict
        :return: the result of the base class' ``_predict_sklearn`` for the converted inputs
        """
        # Apply a transformation to fix a bug in sklearn 1.3.0 (and perhaps earlier versions):
        # https://github.com/scikit-learn/scikit-learn/issues/26768
        inputs = np.ascontiguousarray(input_values)

        return super()._predict_sklearn(inputs)

    def is_sample_weight_supported(self) -> bool:
        """
        :return: always False for this model
        """
        return False