Coverage for src/sensai/data_transformation/sklearn_transformer.py: 63%
49 statements
coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import functools
2import logging
3from typing import Optional, Sequence, Union, Any, Callable
5from sklearn.preprocessing import MaxAbsScaler, StandardScaler, RobustScaler, MinMaxScaler
6import numpy as np
7from typing_extensions import Protocol
log = logging.getLogger(__name__)

# Type alias for the array-like inputs accepted by the transformers in this module:
# either a numpy array or a sequence of row sequences.
TransformableArray = Union[np.ndarray, Sequence[Sequence[Any]]]


def to_2d_array(arr: TransformableArray) -> np.ndarray:
    """
    Converts the given input to a numpy array (if necessary) and checks that it is two-dimensional.

    :param arr: a numpy array (returned as is, without copying) or a nested sequence to be converted via np.array
    :return: the two-dimensional numpy array
    :raises ValueError: if the (converted) array does not have exactly two dimensions
    """
    result = arr if isinstance(arr, np.ndarray) else np.array(arr)
    if result.ndim == 2:
        return result
    raise ValueError(f"Got array of shape {result.shape}; expected 2D array")
class SkLearnTransformerProtocol(Protocol):
    """
    Structural interface for the subset of the scikit-learn transformer API used in this module:
    objects providing fit, transform and inverse_transform.
    """

    def inverse_transform(self, arr: TransformableArray) -> np.ndarray:
        """Maps the given (transformed) array back; expected to be the inverse of transform."""

    def transform(self, arr: TransformableArray) -> np.ndarray:
        """Applies the transformation to the given array and returns the result."""

    def fit(self, arr: TransformableArray):
        """Determines the transformation's parameters (if it has any) from the given array."""
class ManualScaler(SkLearnTransformerProtocol):
    """
    A scaler whose parameters are not learnt from data but manually defined.

    transform computes (x - centre) * scale element-wise; inverse_transform computes x / scale + centre.
    """

    def __init__(self, centre: Optional[float] = None, scale: Optional[float] = None):
        """
        :param centre: the value to subtract from all values (if any); None is treated as 0.0
        :param scale: the value with which to scale all values (after removing the centre); None is treated as 1.0.
            Should be non-zero, as inverse_transform divides by it.
        """
        self.centre = centre if centre is not None else 0.0
        self.scale = scale if scale is not None else 1.0

    def __repr__(self) -> str:
        return f"{type(self).__name__}(centre={self.centre!r}, scale={self.scale!r})"

    def fit(self, arr: TransformableArray) -> "ManualScaler":
        """
        Does nothing, since the parameters are fixed at construction.

        :param arr: ignored
        :return: self, as returned by scikit-learn's fit methods, such that calls can be chained
            (e.g. scaler.fit(x).transform(x)) regardless of which scaler implementation is in use
        """
        return self

    def fit_transform(self, arr: TransformableArray) -> np.ndarray:
        """
        Equivalent to fit followed by transform; provided for parity with scikit-learn's scalers.

        :param arr: a 2D array-like
        :return: the transformed array
        """
        return self.fit(arr).transform(arr)

    def transform(self, arr: TransformableArray) -> np.ndarray:
        """
        :param arr: a 2D array-like
        :return: (arr - centre) * scale, computed element-wise
        :raises ValueError: if arr is not two-dimensional
        """
        arr = to_2d_array(arr)
        return (arr - self.centre) * self.scale

    def inverse_transform(self, arr: TransformableArray) -> np.ndarray:
        """
        Inverts transform.

        :param arr: a 2D array-like of transformed values
        :return: arr / scale + centre, computed element-wise
        :raises ValueError: if arr is not two-dimensional
        """
        arr = to_2d_array(arr)
        # NOTE: with scale == 0 numpy's division yields inf/nan (with a RuntimeWarning) rather than raising
        return (arr / self.scale) + self.centre
# noinspection PyPep8Naming
class SkLearnTransformerFactoryFactory:
    """
    Provides scaler factories: each method returns a function which, when called without any arguments,
    creates a new scaler instance configured with the parameters that were passed to the method.

    Note: within the method bodies, names such as MaxAbsScaler refer to the module-level (imported) classes,
    not to the methods of the same name, because function scopes do not see the class namespace.
    """

    @staticmethod
    def MaxAbsScaler() -> Callable[[], MaxAbsScaler]:
        """
        :return: a function, which when called without any arguments, produces a scikit-learn MaxAbsScaler instance.
        """
        return MaxAbsScaler

    @staticmethod
    def MinMaxScaler(feature_range=(0, 1)) -> Callable[[], MinMaxScaler]:
        """
        :param feature_range: the pair (min, max) passed on to scikit-learn's MinMaxScaler as its `feature_range`,
            i.e. the desired range of the transformed data. The default (0, 1) matches scikit-learn's default.
        :return: a function, which when called without any arguments, produces the respective MinMaxScaler instance.
        """
        return functools.partial(MinMaxScaler, feature_range=feature_range)

    @staticmethod
    def StandardScaler(with_mean=True, with_std=True) -> Callable[[], StandardScaler]:
        """
        :param with_mean: passed on to scikit-learn's StandardScaler; whether to centre the data by subtracting the mean.
        :param with_std: passed on to scikit-learn's StandardScaler; whether to scale the data to unit variance.
        :return: a function, which when called without any arguments, produces the respective StandardScaler instance.
        """
        return functools.partial(StandardScaler, with_mean=with_mean, with_std=with_std)

    @staticmethod
    def RobustScaler(quantile_range=(25, 75), with_scaling=True, with_centering=True) -> Callable[[], RobustScaler]:
        """
        :param quantile_range: a tuple (a, b) where a and b > a (both in range 0..100) are the percentiles which determine the scaling.
            Specifically, each value (after centering) is scaled with 1.0/(vb-va) where va and vb are the values corresponding to the
            percentiles a and b respectively, such that, in the symmetric case where va and vb are equally far from the centre,
            va will be transformed into -0.5 and vb into 0.5.
            In a uniformly distributed data set ranging from `min` to `max`, the default values of a=25 and b=75 will thus result in
            `min` being mapped to -1 and `max` being mapped to 1.
        :param with_scaling: whether to apply scaling based on quantile_range.
        :param with_centering: whether to apply centering by subtracting the median.
        :return: a function, which when called without any arguments, produces the respective RobustScaler instance.
        """
        return functools.partial(RobustScaler, quantile_range=quantile_range, with_scaling=with_scaling, with_centering=with_centering)

    @staticmethod
    def ManualScaler(centre: Optional[float] = None, scale: Optional[float] = None) -> Callable[[], ManualScaler]:
        """
        :param centre: the value to subtract from all values (if any)
        :param scale: the value with which to scale all values (after removing the centre)
        :return: a function, which when called without any arguments, produces the respective scaler instance.
        """
        return functools.partial(ManualScaler, centre=centre, scale=scale)