Coverage for src/sensai/data_transformation/value_transformation.py: 32%
19 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from typing import Sequence, Any
3import numpy as np
class ValueOneHotEncoder:
    """
    A simple one-hot encoder, which encodes individual values.
    A one-hot encoder transforms a categorical input value into an array whose length is the number of categories where all values
    are zero except one whose value is one, indicating the category that is active.
    """
    def __init__(self, ignore_unknown=True):
        """
        :param ignore_unknown: whether unknown input values (not seen during fit) shall be ignored, resulting in an array of zeroes;
            if False, throw an exception instead
        """
        # sorted list of the distinct values seen during fit (None until fit has been called)
        self.categories = None
        # maps each category to its position in the encoded array (None until fit has been called)
        self.category2index = None
        self.ignoreUnknown = ignore_unknown

    def fit(self, values: Sequence[Any]) -> "ValueOneHotEncoder":
        """
        Determines the categories from the given values; each distinct value becomes one category,
        and categories are ordered as returned by np.unique (sorted).

        :param values: the values from which to learn the categories
        :return: self, which allows chaining, e.g. ``encoder.fit(values).transform(value)``
        """
        # np.unique already returns the distinct values in sorted order, so no additional sort is needed
        # (re-sorting in Python would also fail for values Python cannot order, e.g. complex numbers).
        # NOTE(review): np.unique first converts ``values`` to a single array dtype, so mixed-type inputs
        # are coerced (e.g. ints mixed with strings become strings) - confirm inputs are homogeneous.
        self.categories = list(np.unique(values))
        self.category2index = {category: idx for idx, category in enumerate(self.categories)}
        return self

    def transform(self, value) -> np.ndarray:
        """
        Encodes a single value as a one-hot array.

        :param value: the value to encode
        :return: a float array of length ``len(self.categories)`` containing 1.0 at the position of the value's
            category and 0.0 elsewhere; all zeroes if the value is unknown and unknown values are ignored
        :raises RuntimeError: if the encoder has not been fitted yet
        :raises ValueError: if the value was not seen during fit and unknown values are not ignored
        """
        if self.category2index is None:
            # fail with a clear message instead of an obscure TypeError from len(None)
            raise RuntimeError("The encoder must be fitted (see fit) before values can be transformed")
        a = np.zeros(len(self.categories))
        category_idx = self.category2index.get(value)
        if category_idx is None:
            if not self.ignoreUnknown:
                raise ValueError(f"Got unknown value '{value}'")
        else:
            a[category_idx] = 1.0
        return a