Coverage for src/sensai/clustering/greedy_clustering.py: 30%
112 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1# -*- coding: utf-8 -*-
2import logging
3import math
4import queue
5from abc import ABC, abstractmethod
6from typing import List, Sequence, Iterator, Callable, Optional, Union, Tuple
# logger for this module (named after the module itself)
log = logging.getLogger(name=__name__)
class GreedyAgglomerativeClustering(object):
    """
    An implementation of greedy agglomerative clustering which avoids unnecessary
    recomputations of merge costs through the management of a priority queue of
    potential merges.

    Greedy agglomerative clustering works as follows. Starting with an initial
    set of clusters (where each cluster typically contains a single data point),
    the method successively merges the two clusters where the merge cost is lowest (greedy),
    until no further merges are admissible.
    The merge operation is a mutating operation, i.e. the initial clusters are modified.

    To apply the method, the Cluster class must be subclassed, so as to define
    what the cost of a merge in your application shall be and how two clusters can be merged.
    For example, if data points are points in a Cartesian coordinate system, then the merge cost
    can be defined as the minimum or maximum distance among all pairs of points in the two clusters,
    admissibility being determined by a threshold that must not be exceeded;
    the merge operation can simply concatenate lists of data points.
    """
    # Child of the module's logger: within a class body, ``__name__`` resolves to the module's
    # global ``__name__``. The logger is obtained directly from ``logging`` so that the class
    # does not depend on a module-level ``log`` global.
    log = logging.getLogger(__name__).getChild(__qualname__)

    class Cluster(ABC):
        """
        Base class for clusters that can be merged via GreedyAgglomerativeClustering
        """
        @abstractmethod
        def merge_cost(self, other) -> float:
            """
            Computes the cost of merging the given cluster with this cluster

            :param other: the cluster with which a merge is being considered
            :return: the (non-negative) merge cost or math.inf if a merge is inadmissible
            """
            pass

        @abstractmethod
        def merge(self, other):
            """
            Merges the given cluster into this cluster

            :param other: the cluster that is to be merged into this cluster
            """
            pass

    def __init__(self, clusters: Sequence[Cluster],
            merge_candidate_determination_strategy: Optional["GreedyAgglomerativeClustering.MergeCandidateDeterminationStrategy"] = None):
        """
        :param clusters: the initial clusters, which are to be agglomerated into larger clusters
        :param merge_candidate_determination_strategy: the strategy which determines, for each cluster, the clusters
            that are to be evaluated as potential merge partners; if None, use MergeCandidateDeterminationStrategyDefault,
            which considers all pairs of clusters
        """
        # pending merges, ordered by merge cost (see ClusterMerge.__lt__); merges which become obsolete
        # are not removed from the queue but are flagged as evaporated and skipped when dequeued
        self.prioritised_merges = queue.PriorityQueue()
        self.wrapped_clusters = [GreedyAgglomerativeClustering.WrappedCluster(c, idx, self)
            for idx, c in enumerate(clusters)]

        # initialise merge candidate determination strategy
        if merge_candidate_determination_strategy is None:
            merge_candidate_determination_strategy = self.MergeCandidateDeterminationStrategyDefault()
        merge_candidate_determination_strategy.set_clusterer(self)
        self.mergeCandidateDeterminationStrategy = merge_candidate_determination_strategy

    def apply_clustering(self) -> List[Cluster]:
        """
        Applies greedy agglomerative clustering to the clusters given at construction, merging
        clusters until no further merges are admissible

        :return: the list of agglomerated clusters (subset of the original clusters, which may have had other
            clusters merged into them)
        """
        # compute all possible merges, adding them to the priority queue
        self.log.debug("Computing initial merges")
        for idx, wc in enumerate(self.wrapped_clusters):
            if wc.is_merged():
                # only happens if the clustering is applied repeatedly: a cluster that was absorbed by
                # another cluster must not take part in any further merges
                continue
            self.log.debug("Computing potential merges for cluster index %d", idx)
            wc.compute_merges(True)

        # greedily apply the least-cost merges
        steps = 0
        while not self.prioritised_merges.empty():
            self.log.debug("Clustering step %d", steps + 1)
            # dequeue merges until one is found that has not become obsolete
            merge = None
            while merge is None and not self.prioritised_merges.empty():
                candidate = self.prioritised_merges.get()
                if not candidate.evaporated:
                    merge = candidate
            if merge is not None:
                merge.apply()
                steps += 1

        return [wc.cluster for wc in self.wrapped_clusters if not wc.is_merged()]

    class WrappedCluster(object):
        """
        Wrapper for clusters which stores additional data required for clustering (internal use only)
        """
        def __init__(self, cluster, idx, clusterer: "GreedyAgglomerativeClustering"):
            # the wrapped cluster this cluster was merged into (None while this cluster has not been merged)
            self.merged_into_cluster: Optional[GreedyAgglomerativeClustering.WrappedCluster] = None
            # the potential merges (ClusterMerge instances) this cluster participates in
            self.merges = []
            self.cluster = cluster
            # index of this cluster in the clusterer's list of wrapped clusters
            self.idx = idx
            self.clusterer = clusterer

        def is_merged(self) -> bool:
            """
            :return: True if this cluster has been merged into another cluster
            """
            return self.merged_into_cluster is not None

        def get_cluster_association(self) -> "GreedyAgglomerativeClustering.WrappedCluster":
            """
            Gets the wrapped cluster that this cluster's points have ultimately been merged into (which may be the cluster itself)

            :return: the wrapped cluster this cluster's points are associated with
            """
            # follow the chain of merges iteratively: chains can become long, and a recursive
            # implementation could exceed the interpreter's recursion limit
            wc = self
            while wc.merged_into_cluster is not None:
                wc = wc.merged_into_cluster
            return wc

        def remove_merges(self):
            """
            Flags all potential merges this cluster participates in as evaporated (such that they are skipped
            when dequeued) and clears this cluster's list of merges
            """
            for merge in self.merges:
                merge.evaporated = True
            # NOTE: the merge partners' lists still reference these (now evaporated) merges
            self.merges = []

        def compute_merges(self, initial: bool, merged_cluster_indices: Optional[Tuple[int, int]] = None):
            """
            Determines the potential merges of this cluster via the clusterer's merge candidate determination
            strategy and adds every admissible one (i.e. with finite merge cost) to the clusterer's priority queue,
            registering it with both clusters involved

            :param initial: whether the initial merges are being computed (at the start of the clustering algorithm)
            :param merged_cluster_indices: [for initial=False] the pair of cluster indices that were just joined to form
                this cluster
            :raises ValueError: if the strategy provides a ClusterMerge whose first cluster is not this cluster
            """
            wrapped_clusters = self.clusterer.wrapped_clusters
            strategy = self.clusterer.mergeCandidateDeterminationStrategy
            for item in strategy.iter_candidate_indices(self, initial, merged_cluster_indices):
                merge: Optional[GreedyAgglomerativeClustering.ClusterMerge] = None
                if isinstance(item, GreedyAgglomerativeClustering.ClusterMerge):
                    # the strategy provided a fully specified merge
                    merge = item
                    if merge.c1.idx != self.idx:
                        raise ValueError("Merge provided by the merge candidate determination strategy must have %s "
                                         "as its first cluster (c1), got %s" % (self, merge.c1))
                else:
                    # the strategy provided the index of a potential merge partner; any object usable as a
                    # list index is accepted (not only objects of exact type int)
                    other_idx = item
                    if other_idx != self.idx:
                        other = wrapped_clusters[other_idx]
                        if not other.is_merged():
                            merge_cost = self.cluster.merge_cost(other.cluster)
                            if not math.isinf(merge_cost):
                                merge = GreedyAgglomerativeClustering.ClusterMerge(self, other, merge_cost)
                if merge is not None:
                    merge.c1.merges.append(merge)
                    merge.c2.merges.append(merge)
                    self.clusterer.prioritised_merges.put(merge)

        def __str__(self):
            return "Cluster[idx=%d]" % self.idx

    class ClusterMerge(object):
        """
        Represents a potential merge
        """
        # see the comment on GreedyAgglomerativeClustering.log
        log = logging.getLogger(__name__).getChild(__qualname__)

        def __init__(self, c1: "GreedyAgglomerativeClustering.WrappedCluster", c2: "GreedyAgglomerativeClustering.WrappedCluster",
                merge_cost):
            """
            :param c1: the cluster into which c2 is to be merged
            :param c2: the cluster which is to be merged into c1
            :param merge_cost: the cost of the merge; merges with lower cost are applied first
            """
            self.c1 = c1
            self.c2 = c2
            self.merge_cost = merge_cost
            # set once the merge has become obsolete (one of its clusters took part in another merge)
            self.evaporated = False

        def apply(self):
            """
            Applies this merge: merges c2 into c1, flags all potential merges involving either cluster as
            evaporated and computes the potential merges of the resulting cluster c1
            """
            c1, c2 = self.c1, self.c2
            # c2 is merged into c1, which is the cluster that persists
            self.log.debug("Merging %s into %s...", c2, c1)
            c1.cluster.merge(c2.cluster)
            c2.merged_into_cluster = c1
            c1.remove_merges()
            c2.remove_merges()
            self.log.debug("Computing new merge costs for %s...", c1)
            c1.compute_merges(False, merged_cluster_indices=(c1.idx, c2.idx))

        def __lt__(self, other):
            # ordering used by the priority queue: lower merge cost = higher priority
            return self.merge_cost < other.merge_cost

    class MergeCandidateDeterminationStrategy(ABC):
        """
        Determines the indices of clusters which should be evaluated with regard to their merge costs
        """
        def __init__(self):
            self.clusterer: Optional["GreedyAgglomerativeClustering"] = None

        def set_clusterer(self, clusterer: "GreedyAgglomerativeClustering"):
            """
            Initialises the clusterer the strategy is applied to

            :param clusterer: the clusterer
            """
            self.clusterer = clusterer

        @abstractmethod
        def iter_candidate_indices(self, wc: "GreedyAgglomerativeClustering.WrappedCluster", initial: bool,
                merged_cluster_indices: Optional[Tuple[int, int]] = None) -> Iterator[Union[int, "GreedyAgglomerativeClustering.ClusterMerge"]]:
            """
            :param wc: the wrapped cluster: the cluster for which to determine the cluster indices that are to be considered for
                a potential merge
            :param initial: whether we are computing the initial candidates (at the start of the clustering algorithm)
            :param merged_cluster_indices: [for initial=False] the pair of cluster indices that were just joined to form the updated
                cluster wc
            :return: an iterator of cluster indices that should be evaluated as potential merge partners for wc (it may contain the
                index of wc, which will be ignored); instead of an index, an item may also be a fully specified ClusterMerge,
                whose first cluster (c1) must be wc
            """
            pass

    class MergeCandidateDeterminationStrategyDefault(MergeCandidateDeterminationStrategy):
        """
        Considers all pairs of clusters as merge candidates
        """
        def iter_candidate_indices(self, wc: "GreedyAgglomerativeClustering.WrappedCluster", initial: bool,
                merged_cluster_indices: Optional[Tuple[int, int]] = None) -> Iterator[Union[int, "GreedyAgglomerativeClustering.ClusterMerge"]]:
            n = len(self.clusterer.wrapped_clusters)
            if initial:
                # initially, every pair needs to be considered only once: pair wc with higher indices only
                return range(wc.idx + 1, n)
            else:
                # after a merge, the updated cluster needs to be paired with all other clusters
                return range(n)