Coverage for src/sensai/distance_metric.py: 36%
171 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import logging
2import math
3import os
4from abc import abstractmethod, ABC
5from typing import Generic, Sequence, Tuple, List, Union
7import numpy as np
8import pandas as pd
10from .util import cache
11from .util.cache import DelayedUpdateHook, TValue
12from .util.string import object_repr
13from .util.typing import PandasNamedTuple
16log = logging.getLogger(__name__)
class DistanceMetric(ABC):
    """
    Abstract base class for (symmetric) distance metrics
    """
    @abstractmethod
    def distance(self, named_tuple_a: PandasNamedTuple, named_tuple_b: PandasNamedTuple) -> float:
        """
        Compute the distance between two data points.

        :param named_tuple_a: the first data point
        :param named_tuple_b: the second data point
        :return: the distance between the two data points
        """
        pass

    @abstractmethod
    def __str__(self):
        """
        :return: a string describing the metric; concrete subclasses have to override this method
        """
        # Return the inherited representation, so that an override which delegates via
        # super().__str__() receives a string rather than None
        return super().__str__()
class SingleColumnDistanceMetric(DistanceMetric, ABC):
    """
    Base class for distance metrics which compare the values of a single attribute (column) of two data points
    """
    def __init__(self, column: str):
        """
        :param column: the name of the attribute that is read from each of the two named tuples
        """
        self.column = column

    @abstractmethod
    def _distance(self, value_a, value_b) -> float:
        """
        Compute the distance between two values of the configured column.

        :param value_a: the column's value for the first data point
        :param value_b: the column's value for the second data point
        :return: the distance between the two values
        """
        pass

    def distance(self, named_tuple_a: PandasNamedTuple, named_tuple_b: PandasNamedTuple):
        """
        Extract the configured column's value from both data points and delegate to :meth:`_distance`.

        :param named_tuple_a: the first data point
        :param named_tuple_b: the second data point
        :return: the result of :meth:`_distance` for the two extracted values
        """
        attribute_name = self.column
        value_a = getattr(named_tuple_a, attribute_name)
        value_b = getattr(named_tuple_b, attribute_name)
        return self._distance(value_a, value_b)
class DistanceMatrixDFCache(cache.PersistentKeyValueCache[Tuple[Union[str, int], Union[str, int]], TValue], Generic[TValue]):
    """A cache for distance matrices, which are stored as dataframes with identifiers as both index and columns"""
    def __init__(self, pickle_path: str, save_on_update: bool = True, deferred_save_delay_secs: float = 1.0):
        """
        :param pickle_path: path of the pickle file from which the distance dataframe is loaded (if it exists)
            and to which it is saved
        :param save_on_update: whether every call to :meth:`set` shall notify the update hook, which
            is constructed with :meth:`save` as its callback
        :param deferred_save_delay_secs: the delay value that is passed on to the update hook
        """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.save_on_update = save_on_update
        self.pickle_path = pickle_path
        if os.path.exists(self.pickle_path):
            # NOTE: unpickling can execute arbitrary code; pickle_path must point to a trusted file
            self.distance_df = pd.read_pickle(self.pickle_path)
            log.info(f"Successfully loaded dataframe of shape {self.shape()} from cache. "
                     f"There are {self.num_unfilled_entries()} unfilled entries")
        else:
            log.info(f"No cached distance dataframe found in {pickle_path}")
            self.distance_df = pd.DataFrame()
        # maps each identifier to its position in the index; the same position is used for the
        # column in get(), i.e. index and columns are assumed to have the same order
        self.cached_id_to_pos_dict = {identifier: pos for pos, identifier in enumerate(self.distance_df.index)}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)

    def shape(self):
        """
        :return: the pair (n, n), where n is the number of rows of the distance dataframe
        """
        n_entries = len(self.distance_df)
        return n_entries, n_entries

    @staticmethod
    def _assert_tuple(key):
        """
        Assert that the given key is a tuple of exactly two elements.

        :param key: the key to check
        """
        assert isinstance(key, tuple) and len(key) == 2, f"Expected a tuple of two identifiers, instead got {key}"

    def set(self, key: Tuple[Union[str, int], Union[str, int]], value: TValue):
        """
        Store the given value symmetrically for the given pair of identifiers, adding a row and a column
        for any identifier that is not yet present.

        :param key: the pair of identifiers
        :param value: the value to store for the pair
        """
        self._assert_tuple(key)
        for identifier in key:
            if identifier not in self.distance_df.columns:
                log.info(f"Adding new column and row for identifier {identifier}")
                self.distance_df[identifier] = np.nan
                self.distance_df.loc[identifier] = np.nan
                # keep the position lookup in sync; otherwise get() would report a miss for
                # every identifier that was added after construction
                self.cached_id_to_pos_dict[identifier] = self.distance_df.index.get_loc(identifier)
        i1, i2 = key
        log.debug(f"Adding distance value for identifiers {i1}, {i2}")
        self.distance_df.loc[i1, i2] = self.distance_df.loc[i2, i1] = value
        if self.save_on_update:
            self._update_hook.handle_update()

    def save(self):
        """
        Write the distance dataframe to the pickle file, creating the parent directory if necessary.
        """
        log.info(f"Saving new distance matrix to {self.pickle_path}")
        directory = os.path.dirname(self.pickle_path)
        # a bare file name has no directory component, and os.makedirs("") would raise
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.distance_df.to_pickle(self.pickle_path)

    def get(self, key: Tuple[Union[str, int], Union[str, int]]) -> TValue:
        """
        Retrieve the value stored for the given pair of identifiers.

        :param key: the pair of identifiers
        :return: the stored value or None if either identifier is unknown or the entry is unfilled
        """
        self._assert_tuple(key)
        i1, i2 = key
        try:
            pos1, pos2 = self.cached_id_to_pos_dict[i1], self.cached_id_to_pos_dict[i2]
        except KeyError:
            return None
        result = self.distance_df.iloc[pos1, pos2]
        # pd.isnull also handles scalar values that are not floats (for which np.isnan raises a TypeError)
        if pd.isnull(result):
            return None
        return result

    def num_unfilled_entries(self):
        """
        :return: the number of null entries in the distance dataframe
        """
        return self.distance_df.isnull().sum().sum()

    def get_all_cached(self, identifier: Union[str, int]):
        """
        :param identifier: the identifier whose column to retrieve
        :return: a single-column dataframe containing the column stored for the given identifier
        """
        return self.distance_df[[identifier]]
class CachedDistanceMetric(DistanceMetric, cache.CachedValueProviderMixin):
    """
    A decorator which provides caching for a distance metric, i.e. the metric is computed only if the
    value for the given pair of identifiers is not found within the persistent cache
    """

    def __init__(self, distance_metric: DistanceMetric, key_value_cache: cache.KeyValueCache, persist_cache=False):
        """
        :param distance_metric: the metric whose values are to be cached
        :param key_value_cache: the cache, which is passed on to the mixin's constructor
        :param persist_cache: flag which is passed on to the mixin's constructor
        """
        cache.CachedValueProviderMixin.__init__(self, key_value_cache, persist_cache=persist_cache)
        self.metric = distance_metric

    def __getstate__(self):
        # use the mixin's state handling explicitly
        return cache.CachedValueProviderMixin.__getstate__(self)

    def distance(self, named_tuple_a, named_tuple_b):
        """
        Obtain the distance for the given pair of data points via the mixin's value provision,
        using the pair of identifiers (`Index` attributes) in ascending order as the key.

        :param named_tuple_a: the first data point
        :param named_tuple_b: the second data point
        :return: the value returned by the mixin's `_provide_value`
        """
        id_a = named_tuple_a.Index
        id_b = named_tuple_b.Index
        # the smaller identifier always comes first, such that (a, b) and (b, a) use the same key
        if id_b < id_a:
            return self._provide_value((id_b, id_a), (named_tuple_b, named_tuple_a))
        return self._provide_value((id_a, id_b), (named_tuple_a, named_tuple_b))

    def _compute_value(self, key: Tuple[Union[str, int], Union[str, int]], data: Tuple[PandasNamedTuple, PandasNamedTuple]):
        """
        Compute the distance with the wrapped metric.

        :param key: the pair of identifiers (not used in the computation)
        :param data: the pair of data points
        :return: the distance according to the wrapped metric
        """
        first, second = data
        return self.metric.distance(first, second)

    def fill_cache(self, df_indexed_by_id: pd.DataFrame):
        """
        Fill cache for all identifiers in the provided dataframe

        Args:
            df_indexed_by_id: Dataframe that is indexed by identifiers of the members
        """
        for position, row_a in enumerate(df_indexed_by_id.itertuples()):
            if position % 10 == 0:
                percent_done = round(100 * position / len(df_indexed_by_id), 2)
                log.info(f"Processed {percent_done}%")
            # pair the current row with the rows that follow it only, as the key order is normalised
            subsequent_rows = df_indexed_by_id[position + 1:]
            for row_b in subsequent_rows.itertuples():
                self.distance(row_a, row_b)

    def __str__(self):
        return str(self.metric)
class LinearCombinationDistanceMetric(DistanceMetric):
    """
    A distance metric which is the weighted sum of the distances computed by several other metrics
    """
    def __init__(self, metrics: Sequence[Tuple[float, DistanceMetric]]):
        """
        :param metrics: a sequence of tuples (weight, distance metric); entries with weight 0 are dropped
        :raises ValueError: if no entry with non-zero weight remains
        """
        self.metrics = [(weight, metric) for (weight, metric) in metrics if weight != 0]
        if not self.metrics:
            raise ValueError(f"List of metrics is empty after removing all 0-weight metrics; passed {metrics}")

    def distance(self, named_tuple_a, named_tuple_b):
        """
        :param named_tuple_a: the first data point
        :param named_tuple_b: the second data point
        :return: the sum of the individual metrics' distances, each multiplied by its weight
        """
        return sum(
            metric.distance(named_tuple_a, named_tuple_b) * weight
            for weight, metric in self.metrics
        )

    def __str__(self):
        weighted_descriptions = [(weight, str(metric)) for weight, metric in self.metrics]
        return f"Linear combination of {weighted_descriptions}"
class HellingerDistanceMetric(SingleColumnDistanceMetric):
    """
    Distance metric which computes the norm of the difference of the element-wise square roots of
    two arrays, divided by the square root of 2
    """
    _SQRT2 = np.sqrt(2)

    def __init__(self, column: str, check_input=False):
        """
        :param column: the name of the attribute holding the arrays to compare
        :param check_input: whether to validate both arrays before every distance computation
        """
        super().__init__(column)
        self.check_input = check_input

    def __str__(self):
        return object_repr(self, ["column"])

    def _check_input_value(self, input_value):
        """
        Validate that the given value is a numpy array with entries in [0, 1] that sum to 1.

        :param input_value: the value to check
        :raises ValueError: if any of the conditions is violated
        """
        if not isinstance(input_value, np.ndarray):
            raise ValueError(f"Expected to find numpy arrays in {self.column}")

        if not math.isclose(input_value.sum(), 1):
            raise ValueError(f"The entries in {self.column} have to sum to 1")

        # np.all reduces over all dimensions, whereas the builtin all() iterates in Python and
        # raises an "ambiguous truth value" error for arrays with more than one dimension
        if not np.all((input_value >= 0) & (input_value <= 1)):
            raise ValueError(f"The entries in {self.column} have to be in the range [0, 1]")

    def _distance(self, value_a, value_b):
        """
        :param value_a: the first array
        :param value_b: the second array
        :return: norm(sqrt(value_a) - sqrt(value_b)) / sqrt(2)
        :raises ValueError: if input checking is enabled and one of the arrays is invalid
        """
        if self.check_input:
            self._check_input_value(value_a)
            self._check_input_value(value_b)

        return np.linalg.norm(np.sqrt(value_a) - np.sqrt(value_b)) / self._SQRT2
class EuclideanDistanceMetric(SingleColumnDistanceMetric):
    """
    Distance metric which computes the norm of the difference of the two values of a column
    """
    def __init__(self, column: str):
        """
        :param column: the name of the attribute holding the values to compare
        """
        super().__init__(column)

    def _distance(self, value_a, value_b):
        """
        :param value_a: the first value
        :param value_b: the second value
        :return: the norm of (value_a - value_b) as computed by `np.linalg.norm`
        """
        difference = value_a - value_b
        return np.linalg.norm(difference)

    def __str__(self):
        return object_repr(self, ["column"])
class IdentityDistanceMetric(DistanceMetric):
    """
    Distance metric which is 0 if two data points agree on all of the given attributes and 1 otherwise
    """
    def __init__(self, keys: Union[str, List[str]]):
        """
        :param keys: the name(s) of the attribute(s) to compare; a value that is not a list
            is treated as a single key
        """
        key_list = keys if isinstance(keys, list) else [keys]
        assert key_list != [], "At least one key has to be provided"
        self.keys = key_list

    def distance(self, named_tuple_a, named_tuple_b):
        """
        :param named_tuple_a: the first data point
        :param named_tuple_b: the second data point
        :return: 1 if the data points differ in at least one of the keys, 0 otherwise
        """
        has_mismatch = any(
            getattr(named_tuple_a, key) != getattr(named_tuple_b, key)
            for key in self.keys
        )
        return 1 if has_mismatch else 0

    def __str__(self):
        return f"{self.__class__.__name__} based on keys: {self.keys}"
class RelativeBitwiseEqualityDistanceMetric(SingleColumnDistanceMetric):
    """
    Distance metric for arrays, computed as 1 - dot(a, b) / count_nonzero(a + b),
    with distance 0 if a + b has no non-zero entries
    """
    def __init__(self, column: str, check_input=False):
        """
        :param column: the name of the attribute holding the arrays to compare
        :param check_input: whether to validate both arrays before every distance computation
        """
        super().__init__(column)
        self.check_input = check_input

    def check_input_value(self, input_value):
        """
        Validate that the given value is a one-dimensional numpy array with entries in {0, 1}.

        :param input_value: the value to check
        :raises ValueError: if any of the conditions is violated
        """
        if not isinstance(input_value, np.ndarray):
            raise ValueError(f"Expected to find numpy arrays in {self.column}")

        if input_value.ndim != 1:
            raise ValueError("The input array should be of shape (n,)")

        allowed_entries = {0, 1}
        if not set(input_value).issubset(allowed_entries):
            raise ValueError("The input array should only have entries in {0, 1}")

    def _distance(self, value_a, value_b):
        """
        :param value_a: the first array
        :param value_b: the second array
        :return: 0 if the sum of the arrays has no non-zero entries; otherwise 1 minus the ratio of
            the arrays' dot product and the number of non-zero entries of their sum
        :raises ValueError: if input checking is enabled and one of the arrays is invalid
        """
        if self.check_input:
            for value in (value_a, value_b):
                self.check_input_value(value)
        num_nonzero_in_sum = np.count_nonzero(value_a + value_b)
        # guard against division by zero when neither array has a non-zero entry
        if num_nonzero_in_sum == 0:
            return 0
        overlap = np.dot(value_a, value_b)
        return 1 - overlap / num_nonzero_in_sum

    def __str__(self):
        return f"{self.__class__.__name__} for column {self.column}"