Coverage for src/sensai/clustering/sklearn_clustering.py: 67%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

from typing import Optional

import numpy as np
from typing_extensions import Protocol

from . import EuclideanClusterer

5 

6 

class SkLearnClustererProtocol(Protocol):
    """
    Structural (duck-typed) description of the part of the sklearn clusterer API used by this module:
    a ``fit`` method plus a ``labels_`` attribute holding the labels computed by the fit.

    Exists for type hints only; do not instantiate it.
    """
    # cluster labels; read after calling ``fit`` by the sklearn wrapper clusterer in this module
    labels_: np.ndarray

    def fit(self, x: np.ndarray):
        ...

14 

15 

class SkLearnEuclideanClusterer(EuclideanClusterer):
    """
    Wrapper around an sklearn-type clustering algorithm.

    :param clusterer: a clusterer object compatible with the sklearn API, i.e. providing a ``fit`` method
        after which the computed labels are available via the ``labels_`` attribute
    :param noise_label: label that is associated with the noise cluster or None
    :param min_cluster_size: if not None, clusters below this size will be labeled as noise
    :param max_cluster_size: if not None, clusters above this size will be labeled as noise
    """

    def __init__(self, clusterer: SkLearnClustererProtocol, noise_label: Optional[int] = -1,
            min_cluster_size: Optional[int] = None, max_cluster_size: Optional[int] = None):
        super().__init__(noise_label=noise_label, min_cluster_size=min_cluster_size, max_cluster_size=max_cluster_size)
        self.clusterer = clusterer

    def _compute_labels(self, x: np.ndarray) -> np.ndarray:
        """
        Fits the wrapped clusterer on the given data and returns the labels it assigned.

        Note that the wrapped clusterer object is fitted in place, i.e. its state is replaced on every call.

        :param x: the data points to cluster
        :return: the wrapped clusterer's ``labels_`` attribute after fitting
        """
        self.clusterer.fit(x)
        return self.clusterer.labels_

    def __str__(self):
        # extend the base class's string representation with the name of the wrapped clusterer's class
        return f"{super().__str__()}_{self.clusterer.__class__.__name__}"