Coverage for src/sensai/util/datastruct.py: 79%
216 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1from abc import ABC, abstractmethod
2from enum import Enum
3from typing import Sequence, Optional, TypeVar, Generic, Tuple, Dict, Any, TYPE_CHECKING
5from . import sequences as array_util
6from .string import ToStringMixin, dict_string
8if TYPE_CHECKING:
9 import pandas as pd
# Type variable for the keys of the sorted key-value structures defined in this module
TKey = TypeVar("TKey")
# Type variable for the values held by the generic containers defined in this module
TValue = TypeVar("TValue")
# Self-type variable bound to SortedKeyValueStructure; it is used to annotate `self` and the return type of
# SortedKeyValueStructure.slice/_create_slice, so that these are typed as returning the concrete subclass
TSortedKeyValueStructure = TypeVar("TSortedKeyValueStructure", bound="SortedKeyValueStructure")
class Trivalent(Enum):
    """
    A three-valued truth value: in addition to true and false, the value may be unknown.
    """
    TRUE = "true"
    FALSE = "false"
    UNKNOWN = "unknown"

    @classmethod
    def from_bool(cls, b: bool):
        """
        Converts a boolean to the corresponding member (never yields UNKNOWN).

        :param b: the boolean (or any object, whose truthiness is then used)
        :return: TRUE if b is truthy, FALSE otherwise
        """
        if b:
            return cls.TRUE
        return cls.FALSE

    def is_true(self):
        """
        :return: True if and only if this member is TRUE
        """
        # enum members are singletons, so identity and equality coincide
        return self is Trivalent.TRUE

    def is_false(self):
        """
        :return: True if and only if this member is FALSE (i.e. UNKNOWN yields False as well)
        """
        return self is Trivalent.FALSE
class Maybe(Generic[TValue]):
    """
    Minimal container for a single value which may be None; the wrapped object is exposed
    via the attribute ``value``.
    """
    def __init__(self, value: Optional[TValue]):
        """
        :param value: the value to wrap (may be None)
        """
        self.value = value
class DeferredParams(ToStringMixin):
    """
    Represents a dictionary of parameters that is specifically designed to hold parameters that can only be defined late
    within a process (i.e. not initially at construction time), e.g. because the parameters are data-dependent and therefore
    can only be determined once the data has been seen.
    """
    # sentinel used as the default of get_param to indicate that no fallback value was given
    UNDEFINED = "__undefined__DeferredParams"

    def __init__(self):
        # maps parameter names to the values set so far
        self.params = {}

    def _tostring_object_info(self) -> str:
        return dict_string(self.params)

    def set_param(self, name: str, value: Any):
        """
        Sets (or overwrites) a parameter.

        :param name: the parameter name
        :param value: the parameter value
        """
        self.params[name] = value

    def get_param(self, name, default=UNDEFINED):
        """
        :param name: the parameter name
        :param default: in case no value is set, return this value, and if UNDEFINED (default), raise KeyError
        :return: the parameter value
        :raises KeyError: if the parameter is not set and no default was given
        """
        # The sentinel is detected via identity first; equality is applied to strings only (retaining support for
        # callers passing an equal string). Comparing arbitrary defaults with `==` is unsafe, because objects such as
        # numpy arrays or pandas Series return element-wise results whose truth value is ambiguous (ValueError).
        no_default = default is self.UNDEFINED or (isinstance(default, str) and default == self.UNDEFINED)
        if no_default:
            return self.params[name]
        return self.params.get(name, default)

    def get_dict(self) -> Dict[str, Any]:
        """
        :return: the underlying dictionary of parameters (not a copy)
        """
        return self.params
class SortedValues(Generic[TValue]):
    """
    Wraps an already sorted sequence and offers binary search (bisection) lookups on it
    """
    def __init__(self, sorted_values: Sequence[TValue]):
        """
        :param sorted_values: the sequence of values, which is expected to be sorted
        """
        self.values = sorted_values

    def __len__(self):
        return len(self.values)

    # index-based lookups (delegating to the bisection helpers)

    def floor_index(self, value) -> Optional[int]:
        """
        Determines the rightmost index whose value is less than or equal to the given value

        :param value: the value to search for
        :return: the index or None if there is no such index
        """
        return array_util.floor_index(self.values, value)

    def ceil_index(self, value) -> Optional[int]:
        """
        Determines the leftmost index whose value is greater than or equal to the given value

        :param value: the value to search for
        :return: the index or None if there is no such index
        """
        return array_util.ceil_index(self.values, value)

    def closest_index(self, value) -> Optional[int]:
        """
        Determines the index of the value closest to the given value.
        In case of two subsequent values being equally distant, the smaller index is returned.

        :param value: the value to search for
        :return: the index or None if this object is empty
        """
        return array_util.closest_index(self.values, value)

    # value-based lookups (built on top of the index-based ones)

    def _value(self, idx: Optional[int]) -> Optional[TValue]:
        # maps an optional index to the corresponding value, passing None through
        return None if idx is None else self.values[idx]

    def floor_value(self, value) -> Optional[TValue]:
        """
        Determines the largest value that is less than or equal to the given value

        :param value: the value to search for
        :return: the value or None if there is no such value
        """
        idx = self.floor_index(value)
        return self._value(idx)

    def ceil_value(self, value) -> Optional[TValue]:
        """
        Determines the smallest value that is greater than or equal to the given value

        :param value: the value to search for
        :return: the value or None if there is no such value
        """
        idx = self.ceil_index(value)
        return self._value(idx)

    def closest_value(self, value) -> Optional[TValue]:
        """
        Determines the value closest to the given value.
        In case of two subsequent values being equally distant, the smaller value is returned.

        :param value: the value to search for
        :return: the value or None if this object is empty
        """
        idx = self.closest_index(value)
        return self._value(idx)

    def _value_slice(self, first_index, last_index):
        # both indices are inclusive; a missing index on either side means that no slice can be determined
        bounds_known = first_index is not None and last_index is not None
        return self.values[first_index:last_index + 1] if bounds_known else None

    def value_slice(self, lowest_key, highest_key) -> Optional[Sequence[TValue]]:
        """
        :param lowest_key: the lower bound; the slice starts at the leftmost value greater than or equal to it
        :param highest_key: the upper bound; the slice ends at the rightmost value less than or equal to it
        :return: the slice of values or None if either end of the slice cannot be determined
        """
        first_index = self.ceil_index(lowest_key)
        last_index = self.floor_index(highest_key)
        return self._value_slice(first_index, last_index)
class SortedKeyValueStructure(Generic[TKey, TValue], ABC):
    """
    Base class for structures which associate sorted keys with values and support floor/ceil/closest
    lookups by key, interpolation between neighbouring entries and slicing by key bounds
    """
    @abstractmethod
    def __len__(self):
        """
        :return: the number of entries
        """

    @abstractmethod
    def floor_index(self, key: TKey) -> Optional[int]:
        """
        Determines the rightmost index whose key is less than or equal to the given key

        :param key: the key to search for
        :return: the index or None if there is no such index
        """

    @abstractmethod
    def ceil_index(self, key: TKey) -> Optional[int]:
        """
        Determines the leftmost index whose key is greater than or equal to the given key

        :param key: the key to search for
        :return: the index or None if there is no such index
        """

    @abstractmethod
    def closest_index(self, key: TKey) -> Optional[int]:
        """
        Determines the index whose key is closest to the given key.
        In case of two subsequent keys being equally distant, the smaller index is returned.

        :param key: the key to search for
        :return: the index or None if this object is empty.
        """

    @abstractmethod
    def floor_value(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves the value at the largest index whose key is less than or equal to the given key

        :param key: the key to search for
        :return: the value or None if there is no such value
        """

    @abstractmethod
    def ceil_value(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves the value at the smallest index whose key is greater than or equal to the given key

        :param key: the key to search for
        :return: the value or None if there is no such value
        """

    @abstractmethod
    def closest_value(self, key: TKey) -> Optional[TValue]:
        """
        Retrieves the value of the entry whose key is closest to the given key
        (tie-breaking as in :meth:`closest_index`)

        :param key: the key to search for
        :return: the value or None if this object is empty
        """

    @abstractmethod
    def floor_key_and_value(self, key: TKey) -> Optional[Tuple[TKey, TValue]]:
        """
        :param key: the key to search for
        :return: the (key, value) pair at the index given by :meth:`floor_index` or None if there is no such index
        """

    @abstractmethod
    def ceil_key_and_value(self, key: TKey) -> Optional[Tuple[TKey, TValue]]:
        """
        :param key: the key to search for
        :return: the (key, value) pair at the index given by :meth:`ceil_index` or None if there is no such index
        """

    @abstractmethod
    def closest_key_and_value(self, key: TKey) -> Optional[Tuple[TKey, TValue]]:
        """
        :param key: the key to search for
        :return: the (key, value) pair at the index given by :meth:`closest_index` or None if there is no such index
        """

    def interpolated_value(self, key: TKey) -> Optional[TValue]:
        """
        Linearly interpolates between the floor entry and the ceil entry of the given key.
        If floor key and ceil key coincide (as is the case when the key itself is contained), the floor value
        is returned as is.

        NOTE: This operation is supported only for key/value types that support the required arithmetic operations.

        :param key: the key for which the interpolated value is to be computed.
        :return: the interpolated value or None if the data structure does not contain floor/ceil entries for the given key
        """
        floor_entry = self.floor_key_and_value(key)
        ceil_entry = self.ceil_key_and_value(key)
        if floor_entry is None or ceil_entry is None:
            return None
        k_lo, v_lo = floor_entry
        k_hi, v_hi = ceil_entry
        if k_hi == k_lo:
            # no interval to interpolate within
            return v_lo
        # relative position of the key within the interval [k_lo, k_hi]
        frac = (key - k_lo) / (k_hi - k_lo)
        return v_lo + (v_hi - v_lo) * frac

    def slice(self: TSortedKeyValueStructure, lower_bound_key=None, upper_bound_key=None, inner=True) -> TSortedKeyValueStructure:
        """
        Creates a new structure containing the entries selected by the given key bounds.

        :param lower_bound_key: the key defining the start of the slice (depending on inner);
            if None, the first included entry will be the very first entry
        :param upper_bound_key: the key defining the end of the slice (depending on inner);
            if None, the last included entry will be the very last entry
        :param inner: if True, the returned slice will be within the bounds; if False, the returned
            slice is extended by one entry in both directions such that it contains the bounds (where possible)
        :return: the slice, as created by :meth:`_create_slice`
        """
        if lower_bound_key is not None and upper_bound_key is not None:
            # NOTE(review): assertions are stripped when running with -O, in which case this check is skipped
            assert upper_bound_key >= lower_bound_key

        # determine the (inclusive) index of the first entry
        if lower_bound_key is None:
            from_index = 0
        else:
            from_index = self.ceil_index(lower_bound_key) if inner else self.floor_index(lower_bound_key)
            if from_index is None:
                # inner: no key is >= the bound, so start past the end (empty slice);
                # outer: no key is <= the bound, so start at the very first entry
                from_index = len(self) if inner else 0

        # determine the (inclusive) index of the last entry
        if upper_bound_key is None:
            to_index = len(self) - 1
        else:
            to_index = self.floor_index(upper_bound_key) if inner else self.ceil_index(upper_bound_key)
            if to_index is None:
                # inner: no key is <= the bound, so end before the start (empty slice);
                # outer: no key is >= the bound, so end at the very last entry
                to_index = -1 if inner else len(self) - 1

        return self._create_slice(from_index, to_index)

    @abstractmethod
    def _create_slice(self: TSortedKeyValueStructure, from_index: int, to_index: int) -> TSortedKeyValueStructure:
        """
        :param from_index: the index of the first entry to include
        :param to_index: the index of the last entry to include (inclusive; may be -1 in order to request an empty slice)
        :return: a new structure containing the respective entries
        """
class SortedKeysAndValues(Generic[TKey, TValue], SortedKeyValueStructure[TKey, TValue]):
    """
    Sorted key-value structure which is backed by two parallel sequences: one holding the keys and one holding the values
    """
    def __init__(self, keys: Sequence[TKey], values: Sequence[TValue]):
        """
        :param keys: a sorted sequence of keys
        :param values: a sequence of corresponding values
        :raises ValueError: if the two sequences differ in length
        """
        lengths_match = len(keys) == len(values)
        if not lengths_match:
            raise ValueError(f"Lengths of keys ({len(keys)}) and values ({len(values)}) do not match")
        self.keys = keys
        self.values = values

    @classmethod
    def from_series(cls, s: "pd.Series"):
        """
        Alternative constructor which takes keys and values from a pandas Series: the index of the series
        provides the keys and the series' values provide the values

        :param s: the series
        :return: an instance
        """
        # noinspection PyTypeChecker
        return cls(s.index, s.values)

    def __len__(self):
        return len(self.keys)

    # index lookups

    def floor_index(self, key) -> Optional[int]:
        return array_util.floor_index(self.keys, key)

    def ceil_index(self, key) -> Optional[int]:
        return array_util.ceil_index(self.keys, key)

    def closest_index(self, key) -> Optional[int]:
        return array_util.closest_index(self.keys, key)

    # value lookups

    def floor_value(self, key) -> Optional[TValue]:
        return array_util.floor_value(self.keys, key, values=self.values)

    def ceil_value(self, key) -> Optional[TValue]:
        return array_util.ceil_value(self.keys, key, values=self.values)

    def closest_value(self, key) -> Optional[TValue]:
        return array_util.closest_value(self.keys, key, values=self.values)

    # (key, value) lookups

    def _entry_at(self, idx: Optional[int]) -> Optional[Tuple[TKey, TValue]]:
        # maps an optional index to the corresponding (key, value) pair, passing None through
        if idx is None:
            return None
        return self.keys[idx], self.values[idx]

    def floor_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.floor_index(key))

    def ceil_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.ceil_index(key))

    def closest_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.closest_index(key))

    # slicing

    def value_slice_inner(self, lower_bound_key, upper_bound_key):
        """
        Delegates to ``value_slice_inner`` of the sequence utilities, passing this object's keys and values

        :param lower_bound_key: the lower bound key
        :param upper_bound_key: the upper bound key
        :return: the result of the delegate
        """
        return array_util.value_slice_inner(self.keys, lower_bound_key, upper_bound_key, values=self.values)

    def value_slice_outer(self, lower_bound_key, upper_bound_key, fallback=False):
        """
        Delegates to ``value_slice_outer`` of the sequence utilities, passing this object's keys and values

        :param lower_bound_key: the lower bound key
        :param upper_bound_key: the upper bound key
        :param fallback: passed on to the delegate as ``fallback_bounds``
        :return: the result of the delegate
        """
        return array_util.value_slice_outer(self.keys, lower_bound_key, upper_bound_key, values=self.values,
            fallback_bounds=fallback)

    def _create_slice(self, from_index: int, to_index: int) -> "SortedKeysAndValues":
        # to_index is inclusive, so the exclusive end of the range is one past it
        end = to_index + 1
        return SortedKeysAndValues(self.keys[from_index:end], self.values[from_index:end])
class SortedKeyValuePairs(Generic[TKey, TValue], SortedKeyValueStructure[TKey, TValue]):
    """
    Sorted key-value structure which is backed by a single sequence of (key, value) pairs that is sorted by key
    """
    @classmethod
    def from_unsorted_key_value_pairs(cls, unsorted_key_value_pairs: Sequence[Tuple[TKey, TValue]]):
        """
        Alternative constructor for pairs that are not yet sorted: the pairs are sorted by key first

        :param unsorted_key_value_pairs: the (key, value) pairs in arbitrary order
        :return: an instance
        """
        pairs_by_key = sorted(unsorted_key_value_pairs, key=lambda pair: pair[0])
        return cls(pairs_by_key)

    def __init__(self, sorted_key_value_pairs: Sequence[Tuple[TKey, TValue]]):
        """
        :param sorted_key_value_pairs: the (key, value) pairs, which are expected to be sorted by key
        """
        self.entries = sorted_key_value_pairs
        # the keys are extracted once in order to support bisection on them
        keys = [entry[0] for entry in sorted_key_value_pairs]
        self._sortedKeys = SortedValues(keys)

    def __len__(self):
        return len(self.entries)

    # access by index

    def key_for_index(self, idx: int) -> TKey:
        return self.entries[idx][0]

    def value_for_index(self, idx: int) -> TValue:
        return self.entries[idx][1]

    def _value(self, idx: Optional[int]) -> Optional[TValue]:
        # maps an optional index to the corresponding value, passing None through
        return None if idx is None else self.entries[idx][1]

    def _entry_at(self, idx: Optional[int]) -> Optional[Tuple[TKey, TValue]]:
        # maps an optional index to the corresponding (key, value) entry, passing None through
        return None if idx is None else self.entries[idx]

    # floor lookups

    def floor_index(self, key) -> Optional[int]:
        """Determines the rightmost index whose key is less than or equal to the given key"""
        return self._sortedKeys.floor_index(key)

    def floor_value(self, key) -> Optional[TValue]:
        return self._value(self.floor_index(key))

    def floor_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.floor_index(key))

    # ceil lookups

    def ceil_index(self, key) -> Optional[int]:
        """Determines the leftmost index whose key is greater than or equal to the given key"""
        return self._sortedKeys.ceil_index(key)

    def ceil_value(self, key) -> Optional[TValue]:
        return self._value(self.ceil_index(key))

    def ceil_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.ceil_index(key))

    # closest lookups

    def closest_index(self, key) -> Optional[int]:
        return self._sortedKeys.closest_index(key)

    def closest_value(self, key) -> Optional[TValue]:
        return self._value(self.closest_index(key))

    def closest_key_and_value(self, key) -> Optional[Tuple[TKey, TValue]]:
        return self._entry_at(self.closest_index(key))

    # slicing

    def _value_slice(self, first_index, last_index):
        # both indices are inclusive; a missing index on either side means that no slice can be determined
        if first_index is None or last_index is None:
            return None
        selected_entries = self.entries[first_index:last_index + 1]
        return [entry[1] for entry in selected_entries]

    def value_slice(self, lowest_key, highest_key) -> Optional[Sequence[TValue]]:
        """
        :param lowest_key: the lower bound; the slice starts at the leftmost entry whose key is greater than or equal to it
        :param highest_key: the upper bound; the slice ends at the rightmost entry whose key is less than or equal to it
        :return: the list of values or None if either end of the slice cannot be determined
        """
        first_index = self.ceil_index(lowest_key)
        last_index = self.floor_index(highest_key)
        return self._value_slice(first_index, last_index)

    def _create_slice(self, from_index: int, to_index: int) -> "SortedKeyValuePairs":
        # to_index is inclusive, so the exclusive end of the range is one past it
        end = to_index + 1
        return SortedKeyValuePairs(self.entries[from_index:end])