Coverage for src/sensai/clustering/sklearn_clustering.py: 67%
15 statements
« prev | ^ index | » next — coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
from typing import Optional

import numpy as np
from typing_extensions import Protocol

from . import EuclideanClusterer
class SkLearnClustererProtocol(Protocol):
    """
    Structural type for the part of the sklearn clusterer API this module relies on:
    a ``fit`` method that is called with the data, and a ``labels_`` attribute that is
    read afterwards.

    Only used for type hints, do not instantiate
    """

    labels_: np.ndarray

    def fit(self, x: np.ndarray): ...
class SkLearnEuclideanClusterer(EuclideanClusterer):
    """
    Wrapper around an sklearn-type clustering algorithm.

    :param clusterer: a clusterer object compatible with the sklearn API, i.e. providing ``fit``
        and exposing the computed labels as ``labels_`` after fitting
    :param noise_label: label that is associated with the noise cluster or None
    :param min_cluster_size: if not None, clusters below this size will be labeled as noise
    :param max_cluster_size: if not None, clusters above this size will be labeled as noise
    """

    def __init__(self, clusterer: SkLearnClustererProtocol, noise_label: Optional[int] = -1,
            min_cluster_size: Optional[int] = None, max_cluster_size: Optional[int] = None):
        # noise label and cluster-size limits are forwarded unchanged to the base class
        super().__init__(noise_label=noise_label, min_cluster_size=min_cluster_size, max_cluster_size=max_cluster_size)
        self.clusterer = clusterer

    def _compute_labels(self, x: np.ndarray) -> np.ndarray:
        """
        Fit the wrapped clusterer on ``x`` and return the labels it produced.

        Note: this fits the wrapped clusterer object in place, so its state is overwritten
        on every call.

        :param x: the data to cluster
        :return: the clusterer's ``labels_`` attribute after fitting
        """
        self.clusterer.fit(x)
        return self.clusterer.labels_

    def __str__(self) -> str:
        # append the wrapped algorithm's class name to the base representation
        return f"{super().__str__()}_{self.clusterer.__class__.__name__}"