Source code for heat.cluster.kmedoids

"""
Module Implementing the Kmedoids Algorithm
"""

import heat as ht
from heat.cluster._kcluster import _KCluster
from heat.core.dndarray import DNDarray
from typing import Optional, Union, TypeVar


class KMedoids(_KCluster):
    """
    K-medoids-style clustering with the Manhattan distance as fixed metric.

    The new center of a cluster is obtained by computing the median of the
    points assigned to it and then snapping that median to the nearest
    datapoint. This is not the original implementation of k-medoids using PAM
    as originally proposed in [1].

    Parameters
    ----------
    n_clusters : int, optional, default: 8
        The number of clusters to form as well as the number of centroids to
        generate.
    init : str or DNDarray, default: 'random'
        Method for initialization:

        - 'k-medoids++' (also accepted: 'kmedoids++') : selects initial
          cluster centers for the clustering in a smart way to speed up
          convergence. Internally mapped to the ``'probability_based'``
          initialization of the base class.
        - 'random': choose k observations (rows) at random from data for the
          initial centroids.
        - DNDarray: gives the initial centers, should be of
          Shape = (n_clusters, n_features)
    max_iter : int, default: 300
        Maximum number of iterations of the algorithm for a single run.
    random_state : int
        Determines random number generation for centroid initialization.

    References
    ----------
    [1] Kaufman, L. and Rousseeuw, P.J. (1987), Clustering by means of
    Medoids, in Statistical Data Analysis Based on the L1 Norm and Related
    Methods, edited by Y. Dodge, North-Holland, 405-416.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init: Union[str, DNDarray] = "random",
        max_iter: int = 300,
        random_state: Optional[int] = None,
    ):
        # Only compare strings: ``init`` may also be a DNDarray of initial
        # centers, for which ``==`` is not a plain boolean test.
        # Both the documented ('k-medoids++') and the historical
        # ('kmedoids++') spelling select the probability-based seeding.
        if isinstance(init, str) and init in ("kmedoids++", "k-medoids++"):
            init = "probability_based"

        super().__init__(
            metric=lambda x, y: ht.spatial.distance.manhattan(x, y, expand=True),
            n_clusters=n_clusters,
            init=init,
            max_iter=max_iter,
            # medoids are actual datapoints, so convergence is exact equality
            tol=0.0,
            random_state=random_state,
        )

    @staticmethod
    def _fetch_row(x: DNDarray, index: int) -> DNDarray:
        """
        Return row ``index`` (global index along axis 0) of ``x`` on every process.

        Parameters
        ----------
        x : DNDarray
            Input data, Shape = (n_samples, n_features)
        index : int
            Global row index of the requested sample
        """
        if x.split is None:
            # replicated data: every process already holds the full array, so
            # the global index is also the local one and no broadcast is needed
            return ht.array(x.lloc[index, :], device=x.device, comm=x.comm)

        _, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
        # the owner is the last process whose displacement does not exceed ``index``
        proc = 0
        for p in range(x.comm.size):
            if displ[p] > index:
                break
            proc = p

        # receive buffer on the non-owning processes; created on the same
        # device/communicator as ``x`` so that it matches the owner's buffer
        row = ht.zeros(x.shape[1], dtype=x.dtype, device=x.device, comm=x.comm)
        if x.comm.rank == proc:
            local_index = index - displ[proc]
            row = ht.array(x.lloc[local_index, :], device=x.device, comm=x.comm)
        row.comm.Bcast(row, root=proc)
        return row

    def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray) -> DNDarray:
        """
        Compute new centroid ``ci`` as closest sample to the median of the data
        points in ``x`` that are assigned to ``ci``

        Parameters
        ----------
        x : DNDarray
            Input data
        matching_centroids : DNDarray
            Array filled with indices ``i`` indicating to which cluster ``ci``
            each sample point in ``x`` is assigned. Assumed to hold one entry
            per sample and to be distributed like ``x`` along axis 0.
        """
        new_cluster_centers = self._cluster_centers.copy()

        for i in range(self.n_clusters):
            # Process-local boolean mask of the samples assigned to cluster i.
            # Selecting rows with the mask (instead of zeroing the other rows
            # and dropping all-zero rows) keeps genuine all-zero samples in
            # the cluster, so they contribute to the median.
            members = (matching_centroids == i).larray.reshape(-1)
            clean = ht.array(x.larray[members], is_split=x.split)
            clean.balance_()

            # failsafe in case no point is assigned to this cluster
            # draw a random datapoint to continue/restart
            if clean.shape[0] == 0:
                sample = ht.random.randint(0, x.shape[0]).item()
                new_cluster_centers[i, :] = self._fetch_row(x, sample)
                continue

            # with at most one row per process, gather the few rows everywhere
            # before computing the median
            if clean.shape[0] <= ht.MPI_WORLD.size:
                clean.resplit_(axis=None)
            median = ht.median(clean, axis=0, keepdims=True)

            # snap the median to the closest actual sample of ``x``
            dist = self._metric(x, median)
            idx = dist.argmin(axis=0, keepdims=False).item()
            new_cluster_centers[i, :] = self._fetch_row(x, idx)

        return new_cluster_centers

    def fit(
        self, x: DNDarray, oversampling: float = 2, iter_multiplier: float = 1
    ) -> "KMedoids":
        """
        Computes the centroid of a k-medoids clustering.

        Parameters
        ----------
        x : DNDarray
            Training instances to cluster. Shape = (n_samples, n_features)
        oversampling : float
            oversampling factor used in the k-means|| initialization of
            centroids (forwarded to the centroid initialization)
        iter_multiplier : float
            factor that increases the number of iterations used in the
            initialization of centroids (forwarded to the centroid
            initialization)

        Raises
        ------
        ValueError
            If ``x`` is not a DNDarray.
        """
        # input sanitation
        if not isinstance(x, DNDarray):
            raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

        # initialize the clustering
        self._initialize_cluster_centers(x, oversampling, iter_multiplier)
        self._n_iter = 0
        matching_centroids = None

        # iteratively fit the points to the centroids
        for _ in range(self.max_iter):
            # increment the iteration count
            self._n_iter += 1
            # determine the centroids
            matching_centroids = self._assign_to_cluster(x)

            # update the centroids
            new_cluster_centers = self._update_centroids(x, matching_centroids)

            # check whether centroid movement has converged
            if ht.equal(self._cluster_centers, new_cluster_centers):
                break
            self._cluster_centers = new_cluster_centers.copy()

        # ``max_iter == 0``: the loop never ran, so label the samples against
        # the initial centers instead of failing on an unbound variable
        if matching_centroids is None:
            matching_centroids = self._assign_to_cluster(x)

        self._labels = matching_centroids
        return self