Source code for heat.cluster.kmedians
"""
Module implementing the K-Medians clustering algorithm.
"""
import heat as ht
from heat.cluster._kcluster import _KCluster
from heat.core.dndarray import DNDarray
from typing import Optional, Union, TypeVar
[docs]
class KMedians(_KCluster):
    """
    K-Medians clustering algorithm [1].

    Uses the Manhattan (City-block, :math:`L_1`) metric for distance calculations.

    Parameters
    ----------
    n_clusters : int, optional, default: 8
        The number of clusters to form as well as the number of centroids to generate.
    init : str or DNDarray, default: 'random'
        Method for initialization:

        - 'kmedians++' (also accepted as 'k-medians++'): selects initial cluster centers for the
          clustering in a smart way to speed up convergence. Forwarded to the base class as
          ``'probability_based'``.
        - 'random': choose k observations (rows) at random from data for the initial centroids.
        - 'batchparallel': initialize by using the batch parallel algorithm (see
          BatchParallelKMedians for more information).
        - DNDarray: gives the initial centers, should be of Shape = (n_clusters, n_features)
    max_iter : int, default: 300
        Maximum number of iterations of the k-medians algorithm for a single run.
    tol : float, default: 1e-4
        Convergence threshold. :meth:`fit` stops as soon as the summed squared difference between
        the cluster centers of two consecutive iterations is at most this value.
    random_state : int, optional
        Determines random number generation for centroid initialization.

    References
    ----------
    [1] Hakimi, S., and O. Kariv. "An algorithmic approach to network location problems II: The p-medians." SIAM Journal on Applied Mathematics 37.3 (1979): 539-560.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        init: Union[str, DNDarray] = "random",
        max_iter: int = 300,
        tol: float = 1e-4,
        random_state: Optional[int] = None,
    ):
        # ``init`` may be a DNDarray of initial centers; only compare it against the
        # string aliases when it actually is a string, so that no array/string
        # comparison is ever evaluated as a truth value.
        if isinstance(init, str) and init in ("kmedians++", "k-medians++"):
            init = "probability_based"

        super().__init__(
            metric=lambda x, y: ht.spatial.distance.manhattan(x, y, expand=True),
            n_clusters=n_clusters,
            init=init,
            max_iter=max_iter,
            tol=tol,
            random_state=random_state,
        )
        # Not read anywhere in this class; presumably the order of the norm (L1)
        # consumed by the base class -- TODO confirm against _KCluster.
        self._p = 1

    def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray) -> DNDarray:
        """
        Compute coordinates of new centroid as median of the data points in ``x`` that are assigned to it

        Parameters
        ----------
        x : DNDarray
            Input data
        matching_centroids : DNDarray
            Array filled with indices ``i`` indicating to which cluster ``ci`` each sample point in x is assigned

        Returns
        -------
        DNDarray
            A copy of the current cluster centers in which every row has been replaced by the
            feature-wise median of the points assigned to that cluster, or by a randomly drawn
            sample of ``x`` if no point was found for the cluster.
        """
        new_cluster_centers = self._cluster_centers.copy()

        for i in range(self.n_clusters):
            # points in current cluster
            selection = (matching_centroids == i).astype(ht.int64)
            # Remove 0-element lines to avoid spoiling of median
            # NOTE(review): membership is inferred from the masked data, not from
            # ``selection`` itself, so a sample whose features are all exactly zero is
            # dropped here even when it is assigned to cluster ``i``.
            assigned_points = x * selection
            rows = assigned_points.abs().sum(axis=1) != 0
            # filter on the process-local chunks, then rebuild a distributed array from them
            local = assigned_points.larray[rows.larray]
            clean = ht.array(local, is_split=x.split)
            clean.balance_()

            # failsafe in case no point is assigned to this cluster
            # draw a random datapoint to continue/restart
            if clean.shape[0] == 0:
                _, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
                sample = ht.random.randint(0, x.shape[0]).item()
                # find the last process whose displacement does not exceed the sampled index
                proc = 0
                for p in range(x.comm.size):
                    if displ[p] > sample:
                        break
                    proc = p
                # every process allocates a receive buffer; the selected process overwrites
                # it with the sampled row before the broadcast
                xi = ht.zeros(x.shape[1], dtype=x.dtype)
                if x.comm.rank == proc:
                    idx = sample - displ[proc]
                    xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
                xi.comm.Bcast(xi, root=proc)
                new_cluster_centers[i, :] = xi
            else:
                # with at most one row per process, gather the rows on every process
                # before taking the median
                if clean.shape[0] <= ht.MPI_WORLD.size:
                    clean.resplit_(axis=None)
                median = ht.median(clean, axis=0, keepdims=True)
                new_cluster_centers[i : i + 1, :] = median

        return new_cluster_centers

    def fit(self, x: DNDarray, oversampling: float = 2, iter_multiplier: float = 1) -> "KMedians":
        """
        Computes the centroid of a k-medians clustering.

        Parameters
        ----------
        x : DNDarray
            Training instances to cluster. Shape = (n_samples, n_features)
        oversampling : float
            oversampling factor used in the k-means|| initialization of centroids
        iter_multiplier : float
            factor that increases the number of iterations used in the initialization of centroids

        Returns
        -------
        KMedians
            The fitted estimator itself.

        Raises
        ------
        ValueError
            If ``x`` is not a DNDarray.
        """
        # input sanitation
        if not isinstance(x, ht.DNDarray):
            raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

        # initialize the clustering
        self._initialize_cluster_centers(x, oversampling, iter_multiplier)
        self._n_iter = 0

        # iteratively fit the points to the centroids
        for _ in range(self.max_iter):
            # increment the iteration count
            self._n_iter += 1
            # assign every sample to a centroid
            matching_centroids = self._assign_to_cluster(x)
            # update the centroids
            new_cluster_centers = self._update_centroids(x, matching_centroids)
            # check whether centroid movement has converged:
            # summed squared shift between the previous and the new centers
            self._inertia = ((self._cluster_centers - new_cluster_centers) ** 2).sum()
            self._cluster_centers = new_cluster_centers.copy()
            if self.tol is not None and self._inertia <= self.tol:
                break

        # labels are the assignment computed in the last executed iteration,
        # i.e. before that iteration's centroid update
        self._labels = matching_centroids
        return self