Coverage for src/sensai/data_transformation/sklearn_transformer.py: 63%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import functools 

2import logging 

3from typing import Optional, Sequence, Union, Any, Callable 

4 

5from sklearn.preprocessing import MaxAbsScaler, StandardScaler, RobustScaler, MinMaxScaler 

6import numpy as np 

7from typing_extensions import Protocol 

8 

log = logging.getLogger(__name__)

# Array-like input accepted by the transformers in this module: a numpy array or a nested sequence (rows of values)
TransformableArray = Union[np.ndarray, Sequence[Sequence[Any]]]


def to_2d_array(arr: TransformableArray) -> np.ndarray:
    """
    Obtains a numpy array from the given array-like object and ensures that it is two-dimensional.

    :param arr: a numpy array (returned as-is, without conversion or copying) or a nested sequence (converted via np.array)
    :return: the two-dimensional numpy array
    :raises ValueError: if the array does not have exactly two dimensions
    """
    result = arr if isinstance(arr, np.ndarray) else np.array(arr)
    if result.ndim == 2:
        return result
    raise ValueError(f"Got array of shape {result.shape}; expected 2D array")

20 

21 

class SkLearnTransformerProtocol(Protocol):
    """
    Structural type describing the scikit-learn-style transformer interface used in this module:
    fitting to data, transforming values and mapping transformed values back.
    """

    def fit(self, arr: TransformableArray):
        """
        Fits the transformer to the given data (an implementation may ignore the data if its parameters are fixed).
        """

    def transform(self, arr: TransformableArray) -> np.ndarray:
        """
        Applies the transformation to the given values, returning the transformed array.
        """

    def inverse_transform(self, arr: TransformableArray) -> np.ndarray:
        """
        Maps previously transformed values back, returning the resulting array.
        """

31 

32 

class ManualScaler(SkLearnTransformerProtocol):
    """
    A scaler whose parameters are not learnt from data but manually defined.

    The forward transformation is ``(x - centre) * scale``; the inverse transformation is ``x / scale + centre``.
    """

    def __init__(self, centre: Optional[float] = None, scale: Optional[float] = None):
        """
        :param centre: the value to subtract from all values (if any); None is treated as 0.0, i.e. no centring
        :param scale: the factor with which to multiply all values (after removing the centre); None is treated as 1.0,
            i.e. no scaling. Note that inverse_transform divides by this value, so a scale of 0 makes the inverse ill-defined.
        """
        self.centre = centre if centre is not None else 0.0
        self.scale = scale if scale is not None else 1.0

    def fit(self, arr: TransformableArray) -> "ManualScaler":
        """
        Does nothing, because the parameters are defined manually rather than learnt from data.

        :param arr: the data (ignored)
        :return: this instance. Returning self follows the scikit-learn convention for ``fit``, so chained calls such as
            ``scaler.fit(x).transform(x)`` work just as they do for the sklearn scalers this class is used alongside.
        """
        return self

    def transform(self, arr: TransformableArray) -> np.ndarray:
        """
        :param arr: the values to transform; must be (convertible to) a 2D array
        :return: the values with the centre subtracted and the result multiplied by the scale
        :raises ValueError: if arr is not two-dimensional
        """
        arr = to_2d_array(arr)
        return (arr - self.centre) * self.scale

    def inverse_transform(self, arr: TransformableArray) -> np.ndarray:
        """
        Reverses the transformation applied by transform.

        :param arr: the transformed values; must be (convertible to) a 2D array
        :return: the values divided by the scale with the centre added back
        :raises ValueError: if arr is not two-dimensional
        """
        arr = to_2d_array(arr)
        return (arr / self.scale) + self.centre

55 

56 

# noinspection PyPep8Naming
class SkLearnTransformerFactoryFactory:
    """
    Provides factories for scaler instances. Each (static) method is named after the scaler class it creates and
    returns a function which, when called without any arguments, produces a new instance of that scaler configured
    with the parameters passed to the method.
    """

    @staticmethod
    def MaxAbsScaler() -> Callable[[], MaxAbsScaler]:
        """
        :return: a function, which when called without any arguments, produces a (default-configured) sklearn MaxAbsScaler instance.
        """
        return MaxAbsScaler

    @staticmethod
    def MinMaxScaler(feature_range=(0, 1)) -> Callable[[], MinMaxScaler]:
        """
        :param feature_range: a tuple (min, max) specifying the target range of the transformed data, which is passed on
            to sklearn's MinMaxScaler. The default (0, 1) is sklearn's own default, so calling this method without
            arguments yields the same scaler configuration as the class itself.
        :return: a function, which when called without any arguments, produces the respective MinMaxScaler instance.
        """
        return functools.partial(MinMaxScaler, feature_range=feature_range)

    @staticmethod
    def StandardScaler(with_mean=True, with_std=True) -> Callable[[], StandardScaler]:
        """
        :param with_mean: whether to centre the data (subtract the mean); passed on to sklearn's StandardScaler.
        :param with_std: whether to scale the data to unit variance; passed on to sklearn's StandardScaler.
        :return: a function, which when called without any arguments, produces the respective StandardScaler instance.
        """
        return functools.partial(StandardScaler, with_mean=with_mean, with_std=with_std)

    @staticmethod
    def RobustScaler(quantile_range=(25, 75), with_scaling=True, with_centering=True) -> Callable[[], RobustScaler]:
        """
        :param quantile_range: a tuple (a, b) where a and b > a (both in range 0..100) are the percentiles which determine the scaling.
            Specifically, each value (after centering) is scaled with 1.0/(vb-va) where va and vb are the values corresponding to the
            percentiles a and b respectively, such that, in the symmetric case where va and vb are equally far from the centre,
            va will be transformed into -0.5 and vb into 0.5.
            In a uniformly distributed data set ranging from `min` to `max`, the default values of a=25 and b=75 will thus result in
            `min` being mapped to -1 and `max` being mapped to 1.
        :param with_scaling: whether to apply scaling based on quantile_range.
        :param with_centering: whether to apply centering by subtracting the median.
        :return: a function, which when called without any arguments, produces the respective RobustScaler instance.
        """
        return functools.partial(RobustScaler, quantile_range=quantile_range, with_scaling=with_scaling, with_centering=with_centering)

    @staticmethod
    def ManualScaler(centre: Optional[float] = None, scale: Optional[float] = None) -> Callable[[], ManualScaler]:
        """
        :param centre: the value to subtract from all values (if any)
        :param scale: the value with which to scale all values (after removing the centre)
        :return: a function, which when called without any arguments, produces the respective scaler instance.
        """
        return functools.partial(ManualScaler, centre=centre, scale=scale)