Coverage for src/sensai/distance_metric.py: 36%

171 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import logging 

2import math 

3import os 

4from abc import abstractmethod, ABC 

5from typing import Generic, Sequence, Tuple, List, Union 

6 

7import numpy as np 

8import pandas as pd 

9 

10from .util import cache 

11from .util.cache import DelayedUpdateHook, TValue 

12from .util.string import object_repr 

13from .util.typing import PandasNamedTuple 

14 

15 

# module-level logger, named after this module
log = logging.getLogger(name=__name__)

17 

18 

class DistanceMetric(ABC):
    """
    Abstract base class for (symmetric) distance metrics
    """
    @abstractmethod
    def distance(self, named_tuple_a: PandasNamedTuple, named_tuple_b: PandasNamedTuple) -> float:
        """
        Computes the distance between two data points.

        :param named_tuple_a: the first data point (a named tuple, e.g. a row as produced by ``DataFrame.itertuples``)
        :param named_tuple_b: the second data point
        :return: the distance between the two data points
        """
        pass

    @abstractmethod
    def __str__(self):
        """
        Returns a human-readable description of the metric; must be overridden by subclasses.
        """
        # Return the result so that subclasses delegating via super().__str__() obtain a string
        # (discarding it would make str() raise a TypeError).
        return super().__str__()

30 

31 

class SingleColumnDistanceMetric(DistanceMetric, ABC):
    """
    Base class for distance metrics which compare two data points based on a single column only.

    Subclasses implement :meth:`_distance`, which receives the two values of the configured column.
    """

    def __init__(self, column: str):
        """
        :param column: the name of the named-tuple attribute (column) whose values are compared
        """
        self.column = column

    @abstractmethod
    def _distance(self, value_a, value_b) -> float:
        """
        Computes the distance between two values taken from the configured column.
        """
        pass

    def distance(self, named_tuple_a: PandasNamedTuple, named_tuple_b: PandasNamedTuple):
        col = self.column
        first = getattr(named_tuple_a, col)
        second = getattr(named_tuple_b, col)
        return self._distance(first, second)

43 

44 

class DistanceMatrixDFCache(cache.PersistentKeyValueCache[Tuple[Union[str, int], Union[str, int]], TValue], Generic[TValue]):
    """A cache for distance matrices, which are stored as dataframes with identifiers as both index and columns"""
    def __init__(self, pickle_path: str, save_on_update: bool = True, deferred_save_delay_secs: float = 1.0):
        """
        :param pickle_path: path of the pickle file from which the distance dataframe is loaded (if it exists)
            and to which it is saved
        :param save_on_update: whether each call to :meth:`set` notifies the update hook, which triggers a
            (deferred) call to :meth:`save`
        :param deferred_save_delay_secs: the delay passed on to the ``DelayedUpdateHook`` that calls :meth:`save`
        """
        self.deferred_save_delay_secs = deferred_save_delay_secs
        self.save_on_update = save_on_update
        self.pickle_path = pickle_path
        if os.path.exists(self.pickle_path):
            # NOTE: pd.read_pickle unpickles arbitrary objects; only load cache files from trusted locations
            self.distance_df = pd.read_pickle(self.pickle_path)
            log.info(f"Successfully loaded dataframe of shape {self.shape()} from cache. "
                     f"There are {self.num_unfilled_entries()} unfilled entries")
        else:
            log.info(f"No cached distance dataframe found in {pickle_path}")
            self.distance_df = pd.DataFrame()
        # maps each identifier to its position in the index; get() uses the same position for the columns,
        # which holds for matrices built via set() (row and column are always added together)
        self.cached_id_to_pos_dict = {identifier: pos for pos, identifier in enumerate(self.distance_df.index)}
        self._update_hook = DelayedUpdateHook(self.save, deferred_save_delay_secs)

    def shape(self):
        """
        :return: the (square) shape of the distance matrix, derived from the number of rows
        """
        n_entries = len(self.distance_df)
        return n_entries, n_entries

    @staticmethod
    def _assert_tuple(key):
        assert isinstance(key, tuple) and len(key) == 2, f"Expected a tuple of two identifiers, instead got {key}"

    def set(self, key: Tuple[Union[str, int], Union[str, int]], value: TValue):
        """
        Stores the value for the given pair of identifiers (symmetrically, i.e. for both (a, b) and (b, a)),
        adding a new row and column (filled with NaN) for any identifier not yet present.

        :param key: a tuple of two identifiers
        :param value: the value to store
        """
        self._assert_tuple(key)
        for identifier in key:
            if identifier not in self.distance_df.columns:
                log.info(f"Adding new column and row for identifier {identifier}")
                self.distance_df[identifier] = np.nan
                self.distance_df.loc[identifier] = np.nan
                # keep the position lookup used by get() in sync; without this, values stored for new
                # identifiers could not be retrieved until the cache is reloaded from disk
                self.cached_id_to_pos_dict[identifier] = self.distance_df.index.get_loc(identifier)
        i1, i2 = key
        log.debug(f"Adding distance value for identifiers {i1}, {i2}")
        self.distance_df.loc[i1, i2] = self.distance_df.loc[i2, i1] = value
        if self.save_on_update:
            self._update_hook.handle_update()

    def save(self):
        """
        Writes the distance dataframe to the pickle path, creating the parent directory if necessary.
        """
        log.info(f"Saving new distance matrix to {self.pickle_path}")
        directory = os.path.dirname(self.pickle_path)
        # a bare file name has no directory component, and os.makedirs("") raises FileNotFoundError
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.distance_df.to_pickle(self.pickle_path)

    def get(self, key: Tuple[Union[str, int], Union[str, int]]) -> TValue:
        """
        :param key: a tuple of two identifiers
        :return: the stored value, or None if either identifier is unknown or no value has been stored (NaN)
        """
        self._assert_tuple(key)
        i1, i2 = key
        try:
            pos1, pos2 = self.cached_id_to_pos_dict[i1], self.cached_id_to_pos_dict[i2]
        except KeyError:
            return None
        result = self.distance_df.iloc[pos1, pos2]
        if np.isnan(result):
            return None
        return result

    def num_unfilled_entries(self):
        """
        :return: the number of NaN entries in the matrix (including unset diagonal entries)
        """
        return self.distance_df.isnull().sum().sum()

    def get_all_cached(self, identifier: Union[str, int]):
        """
        :param identifier: the identifier whose column to return
        :return: the column of the given identifier as a single-column dataframe
        """
        return self.distance_df[[identifier]]

104 

105 

class CachedDistanceMetric(DistanceMetric, cache.CachedValueProviderMixin):
    """
    A decorator which provides caching for a distance metric, i.e. the metric is computed only if the
    value for the given pair of identifiers is not found within the persistent cache
    """

    def __init__(self, distance_metric: DistanceMetric, key_value_cache: cache.KeyValueCache, persist_cache=False):
        """
        :param distance_metric: the metric whose values are to be cached
        :param key_value_cache: the cache in which values are stored, keyed by (ordered) pairs of identifiers
        :param persist_cache: passed on to ``CachedValueProviderMixin`` (presumably controls whether the cache
            is retained when this object is pickled; TODO confirm)
        """
        cache.CachedValueProviderMixin.__init__(self, key_value_cache, persist_cache=persist_cache)
        self.metric = distance_metric

    def __getstate__(self):
        return cache.CachedValueProviderMixin.__getstate__(self)

    def distance(self, named_tuple_a, named_tuple_b):
        """
        Returns the (possibly cached) distance between two rows, identified by their ``Index`` attribute.
        """
        id_a, id_b = named_tuple_a.Index, named_tuple_b.Index
        # order the pair canonically so that (a, b) and (b, a) share one cache entry (the metric is symmetric)
        if id_b < id_a:
            id_a, id_b, named_tuple_a, named_tuple_b = id_b, id_a, named_tuple_b, named_tuple_a
        return self._provide_value((id_a, id_b), (named_tuple_a, named_tuple_b))

    def _compute_value(self, key: Tuple[Union[str, int], Union[str, int]], data: Tuple[PandasNamedTuple, PandasNamedTuple]):
        """
        Computes the value for a cache miss by delegating to the wrapped metric.
        """
        value_a, value_b = data
        return self.metric.distance(value_a, value_b)

    def fill_cache(self, df_indexed_by_id: pd.DataFrame):
        """
        Fill cache for all identifiers in the provided dataframe

        Args:
            df_indexed_by_id: Dataframe that is indexed by identifiers of the members
        """
        # Materialise the rows once and slice positionally: df[position + 1:] is label-based for
        # non-integer (e.g. float) indexes and would recreate the named tuples for every row.
        rows = list(df_indexed_by_id.itertuples())
        num_rows = len(rows)
        for position, value_a in enumerate(rows):
            if position % 10 == 0:
                log.info(f"Processed {round(100 * position / num_rows, 2)}%")
            for value_b in rows[position + 1:]:
                self.distance(value_a, value_b)

    def __str__(self):
        return str(self.metric)

144 

145 

class LinearCombinationDistanceMetric(DistanceMetric):
    """
    A distance metric defined as the weighted sum of other distance metrics.
    """
    def __init__(self, metrics: Sequence[Tuple[float, DistanceMetric]]):
        """
        :param metrics: a sequence of tuples (weight, distance metric); entries with weight 0 are dropped
        :raises ValueError: if no metric with a non-zero weight remains
        """
        weighted = []
        for weight, metric in metrics:
            if weight != 0:
                weighted.append((weight, metric))
        self.metrics = weighted
        if not weighted:
            raise ValueError(f"List of metrics is empty after removing all 0-weight metrics; passed {metrics}")

    def distance(self, named_tuple_a, named_tuple_b):
        return sum(metric.distance(named_tuple_a, named_tuple_b) * weight for weight, metric in self.metrics)

    def __str__(self):
        described = [(weight, str(metric)) for weight, metric in self.metrics]
        return f"Linear combination of {described}"

163 

164 

class HellingerDistanceMetric(SingleColumnDistanceMetric):
    """
    Hellinger distance between two discrete probability distributions stored in a single column:
    H(P, Q) = ||sqrt(P) - sqrt(Q)||_2 / sqrt(2).
    """
    _SQRT2 = np.sqrt(2)

    def __init__(self, column: str, check_input=False):
        """
        :param column: the column containing the distributions
        :param check_input: whether to validate both values before each distance computation
            (numpy array, entries summing to 1, entries in [0, 1])
        """
        super().__init__(column)
        self.check_input = check_input

    def __str__(self):
        return object_repr(self, ["column"])

    def _check_input_value(self, input_value):
        """
        :raises ValueError: if the value is not a numpy array, does not sum to 1 or has entries outside [0, 1]
        """
        if not isinstance(input_value, np.ndarray):
            raise ValueError(f"Expected to find numpy arrays in {self.column}")

        if not math.isclose(input_value.sum(), 1):
            raise ValueError(f"The entries in {self.column} have to sum to 1")

        # np.all evaluates the check in one vectorised pass; the builtin all() iterates in Python
        # and fails for arrays with more than one dimension
        if not np.all((input_value >= 0) & (input_value <= 1)):
            raise ValueError(f"The entries in {self.column} have to be in the range [0, 1]")

    def _distance(self, value_a, value_b):
        if self.check_input:
            self._check_input_value(value_a)
            self._check_input_value(value_b)

        return np.linalg.norm(np.sqrt(value_a) - np.sqrt(value_b)) / self._SQRT2

191 

192 

class EuclideanDistanceMetric(SingleColumnDistanceMetric):
    """
    Euclidean (L2) distance between the values of a single column.
    """
    def __init__(self, column: str):
        """
        :param column: the column containing the values to compare
        """
        super().__init__(column)

    def __str__(self):
        return object_repr(self, ["column"])

    def _distance(self, value_a, value_b):
        difference = value_a - value_b
        return np.linalg.norm(difference)

202 

203 

class IdentityDistanceMetric(DistanceMetric):
    """
    Distance metric which is 0 if two data points agree in all of the given keys and 1 otherwise.
    """
    def __init__(self, keys: Union[str, List[str]]):
        """
        :param keys: a single key or a list (or any other iterable) of keys, i.e. names of the
            named-tuple attributes which are compared
        """
        if isinstance(keys, str):
            keys = [keys]
        elif not isinstance(keys, list):
            # accept tuples and other iterables of key names; wrapping them as a single (non-string)
            # key would make getattr() fail in distance()
            keys = list(keys)
        assert keys != [], "At least one key has to be provided"
        self.keys = keys

    def distance(self, named_tuple_a, named_tuple_b):
        """
        :return: 1 if the data points differ in any of the keys, 0 otherwise
        """
        for key in self.keys:
            if getattr(named_tuple_a, key) != getattr(named_tuple_b, key):
                return 1
        return 0

    def __str__(self):
        return f"{self.__class__.__name__} based on keys: {self.keys}"

219 

220 

class RelativeBitwiseEqualityDistanceMetric(SingleColumnDistanceMetric):
    """
    Distance between two binary vectors (entries in {0, 1}) stored in a single column, computed as
    1 - |a AND b| / |a OR b| (the Jaccard distance of the positions set to 1).
    Two all-zero vectors have distance 0.
    """
    def __init__(self, column: str, check_input=False):
        """
        :param column: the column containing the binary vectors
        :param check_input: whether to validate both values before each distance computation
        """
        super().__init__(column)
        self.check_input = check_input

    def check_input_value(self, input_value):
        """
        :raises ValueError: if the value is not a one-dimensional numpy array with entries in {0, 1}
        """
        if not isinstance(input_value, np.ndarray):
            raise ValueError(f"Expected to find numpy arrays in {self.column}")

        if input_value.ndim != 1:
            raise ValueError("The input array should be of shape (n,)")

        if not set(input_value).issubset({0, 1}):
            raise ValueError("The input array should only have entries in {0, 1}")

    def _distance(self, value_a, value_b):
        if self.check_input:
            self.check_input_value(value_a)
            self.check_input_value(value_b)
        # convert to arrays so that plain sequences are processed element-wise: for lists, '+' would
        # concatenate and silently yield a wrong denominator
        value_a, value_b = np.asarray(value_a), np.asarray(value_b)
        # for 0/1 entries, a + b is non-zero exactly where either vector is 1 (size of the union),
        # and the dot product counts the positions where both are 1 (size of the intersection)
        denom = np.count_nonzero(value_a + value_b)
        if denom == 0:
            return 0
        return 1 - np.dot(value_a, value_b) / denom

    def __str__(self):
        return f"{self.__class__.__name__} for column {self.column}"