"""
Base-module for k-clustering algorithms
"""
import heat as ht
import torch
from heat.cluster.batchparallelclustering import _kmex
from typing import Optional, Union, Callable
from heat.core.dndarray import DNDarray
import warnings
[docs]
class _KCluster(ht.ClusteringMixin, ht.BaseEstimator):
    """
    Base class for k-statistics clustering algorithms (kmeans, kmedians, kmedoids).

    The clusters are represented by centroids ci (we use the term from kmeans for simplicity).
    This class provides the shared centroid initialization, the assignment step and ``predict``;
    ``fit`` and ``_update_centroids`` raise ``NotImplementedError`` here and have to be provided
    by the concrete algorithm.

    Parameters
    ----------
    metric : function
        One of the distance metrics in ht.spatial.distance. Needs to be passed as lambda function
        to take only two arrays as input
    n_clusters : int
        The number of clusters to form as well as the number of centroids to generate.
    init : str or DNDarray
        Method for initialization:

            - 'probability_based' : selects initial cluster centers for the clustering in a smart
              way to speed up convergence (k-means|| oversampling followed by a k-means++
              reclustering of the candidates)
            - 'random': choose k observations (rows) at random from data for the initial centroids.
            - 'batchparallel': use the batch parallel algorithm to initialize the centroids, only
              available for split=0 and KMeans or KMedians
            - ``DNDarray``: gives the initial centers, should be of Shape = (n_clusters, n_features)
    max_iter : int
        Maximum number of iterations for a single run.
    tol : float
        Relative tolerance with regards to inertia to declare convergence.
    random_state : int, optional
        Determines random number generation for centroid initialization. If ``None``, the global
        random state is left untouched.
    """

    def __init__(
        self,
        metric: Callable,
        n_clusters: int,
        init: Union[str, DNDarray],
        max_iter: int,
        tol: float,
        random_state: Optional[int],
    ):  # noqa: D107
        self.n_clusters = n_clusters
        self.init = init
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

        # in-place properties
        self._metric = metric
        self._cluster_centers = None
        self._functional_value = None
        self._labels = None
        self._inertia = None
        self._n_iter = None
        # order p of the norm used for the functional value; not set by this base class
        self._p = None

    @property
    def cluster_centers_(self) -> DNDarray:
        """
        Returns the coordinates of the cluster centers.
        If the algorithm stops before fully converging (see ``tol`` and ``max_iter``),
        these will not be consistent with :func:`labels_`.
        """
        return self._cluster_centers

    @property
    def labels_(self) -> DNDarray:
        """
        Returns the labels of each point
        """
        return self._labels

    @property
    def inertia_(self) -> float:
        """
        Returns the sum of squared distances of samples to their closest cluster center.
        """
        return self._inertia

    @property
    def n_iter_(self) -> int:
        """
        Returns the number of iterations run.
        """
        return self._n_iter

    @property
    def functional_value_(self) -> DNDarray:
        """
        Returns the K-Clustering functional value of the clustering algorithm
        """
        return self._functional_value

    def _initialize_cluster_centers(
        self, x: DNDarray, oversampling: float, iter_multiplier: float
    ) -> None:
        """
        Initializes the centroids according to ``self.init`` and stores them in
        ``self._cluster_centers``.

        Parameters
        ----------
        x : DNDarray
            The data to initialize the clusters for. Shape = (n_samples, n_features)
        oversampling : float
            oversampling factor used in the k-means|| initialization of centroids
        iter_multiplier : float
            factor that increases the number of iterations used in the initialization of centroids

        Raises
        ------
        ValueError
            If ``x`` is not a DNDarray, ``oversampling < 2``, ``iter_multiplier < 1``, directly
            passed centroids have the wrong shape, the probability based initialization does not
            find enough candidates, or ``self.init`` is not a supported value.
        NotImplementedError
            If the split axis of ``x`` is not supported by the chosen initialization.
        """
        # input sanitation
        if not isinstance(x, DNDarray):
            raise ValueError(f"Input x needs to be a ht.DNDarray, but was {type(x)}")
        if oversampling < 2:
            raise ValueError(f"Oversampling factor should be at least 2, but was {oversampling}")
        if iter_multiplier < 1:
            raise ValueError(
                f"Iteration multiplier should be at least 1, but was {iter_multiplier}"
            )

        # always initialize the random state
        if self.random_state is not None:
            ht.random.seed(self.random_state)

        # directly passed centroids
        # NOTE: this check comes first so that a DNDarray is never compared against the
        # string literals below (``==`` on a DNDarray is an array operation, not a type test)
        if isinstance(self.init, DNDarray):
            if len(self.init.shape) != 2:
                raise ValueError(
                    f"passed centroids need to be two-dimensional, but are {len(self.init.shape)}"
                )
            if self.init.shape[0] != self.n_clusters or self.init.shape[1] != x.shape[1]:
                raise ValueError("passed centroids do not match cluster count or data shape")
            self._cluster_centers = self.init.resplit(None)

        # initialize the centroids by randomly picking some of the points
        elif self.init == "random":
            # randint samples from the half-open interval [low, high): the upper bound has to be
            # the number of samples, otherwise the last row could never be selected
            idx = ht.random.randint(0, x.shape[0], size=(self.n_clusters,), split=None)
            centroids = x[idx, :]
            self._cluster_centers = centroids if x.split == 1 else centroids.resplit(None)

        # Parallelized centroid guessing using the k-means|| algorithm
        elif self.init == "probability_based":
            # First, check along which axis the data is sliced
            if x.split is None or x.split == 0:
                # Define a random integer serving as a label to pick the first centroid randomly
                # (half-open interval, see above)
                init_idx = ht.random.randint(0, x.shape[0]).item()
                # Randomly select first centroid and organize it as a tensor, in order to use the
                # function cdist later. This tensor will be filled continuously in the course of
                # this function. We assume that the centroids fit into the memory of a single GPU
                init_centroids = ht.expand_dims(x[init_idx, :].resplit(None), axis=0)
                # Calculate the initial cost of the clustering after the first centroid selection
                # and use it as an indicator for the order of magnitude for the number of
                # necessary iterations
                init_distance = ht.spatial.distance.cdist(
                    x, init_centroids, quadratic_expansion=True
                )
                # --> distance between the data points x and the initial centroid
                init_min_distance = init_distance.min(axis=1)
                # --> minimal distance of each data point to the centroids (vector)
                init_cost = init_min_distance.sum()
                # --> cost of the initial clustering (scalar)

                # Iteratively fill the tensor storing the centroids
                num_iters = max(
                    1, int(iter_multiplier * ht.log(init_cost))
                )  # ensure at least one iteration
                centroids = self._centroid_sampling_helper(
                    x, init_centroids, oversampling, num_iters
                )
                # Check if enough centroids were found; increase oversampling factor
                # automatically if necessary
                if centroids.shape[0] <= self.n_clusters:
                    warnings.warn(
                        f"Oversampling={oversampling} is too low for data set. "
                        "Increasing it by factor 10 automatically. And restarting centroid initialization.",
                        UserWarning,
                    )
                    oversampling = 10 * oversampling
                    centroids = self._centroid_sampling_helper(
                        x, init_centroids, oversampling, num_iters
                    )
                    # Raise ValueError if still not enough centroids are found
                    if centroids.shape[0] <= self.n_clusters:
                        raise ValueError(
                            f"The parameter oversampling={oversampling} and/or iter_multiplier={iter_multiplier} "
                            "are chosen too small for the initialization of cluster centers."
                        )

                # Evaluate the distance between data and the final set of centroids for the
                # initialization
                final_distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
                # For each data point in x, find the index of the centroid that is closest
                final_idx = ht.argmin(final_distance, axis=1)
                # Introduce weights, i.e., the number of data points closest to each centroid
                # (count how often the same index in final_idx occurs)
                weights = ht.zeros(centroids.shape[0], split=centroids.split)
                for i in range(centroids.shape[0]):
                    weights[i] = ht.sum(final_idx == i)

                # Recluster the oversampled centroids using standard k-means ++ (here we use the
                # already implemented version in torch)
                # --> first transform relevant arrays into torch tensors
                centroids = centroids.resplit(None)
                centroids = centroids.larray
                weights = weights.resplit(None)
                weights = weights.larray
                if ht.MPI_WORLD.rank == 0:
                    # --> apply standard k-means ++
                    batch_kmeans = _kmex(
                        centroids,
                        p=2,
                        n_clusters=self.n_clusters,
                        init="++",
                        max_iter=self.max_iter,
                        tol=self.tol,
                        random_state=None,
                        weights=weights,
                    )
                    # Note: as we only recluster the centroids for initialization with standard
                    # k-means ++, this list of centroids can also be used to initialize k-medians
                    # and k-medoids
                    # --> access the reclustered centroids
                    reclustered_centroids = batch_kmeans[0]
                else:
                    # ensure that all processes have the same data:
                    # tensor with zeros that has the same size as the reclustered centroids, in
                    # order to allocate memory with the correct type in all processes
                    # (necessary for broadcast)
                    reclustered_centroids = torch.zeros(
                        (self.n_clusters, centroids.shape[1]),
                        dtype=x.dtype.torch_type(),
                        device=centroids.device,
                    )
                # broadcast the result of process 0 to all other processes
                ht.MPI_WORLD.Bcast(reclustered_centroids, root=0)
                # --> transform back to DNDarray
                reclustered_centroids = ht.array(reclustered_centroids, split=None)
                # --> final result for initialized cluster centers
                self._cluster_centers = reclustered_centroids
            else:
                raise NotImplementedError("Not implemented for other splitting-axes")

        elif self.init == "batchparallel":
            if x.split == 0:
                # self._p selects the matching batch parallel algorithm
                # (2 -> KMeans, 1 -> KMedians)
                if self._p == 2:
                    batch_parallel_clusterer = ht.cluster.BatchParallelKMeans(
                        n_clusters=self.n_clusters,
                        init="k-means++",
                        max_iter=100,
                        random_state=self.random_state,
                    )
                elif self._p == 1:
                    batch_parallel_clusterer = ht.cluster.BatchParallelKMedians(
                        n_clusters=self.n_clusters,
                        init="k-medians++",
                        max_iter=100,
                        random_state=self.random_state,
                    )
                else:
                    raise ValueError(
                        "Batch parallel initialization only implemented for KMeans and KMedians"
                    )
                batch_parallel_clusterer.fit(x)
                self._cluster_centers = batch_parallel_clusterer.cluster_centers_
            else:
                raise NotImplementedError(
                    f"Batch parallel initalization only implemented for split = 0, but split was {x.split}"
                )

        else:
            # list exactly the values that are handled by the branches above
            raise ValueError(
                'init needs to be one of "random", ht.DNDarray, "probability_based", '
                f'or "batchparallel", but was {self.init}'
            )

    def _centroid_sampling_helper(
        self, x: DNDarray, centroids: DNDarray, oversampling: float, num_iters: int
    ) -> DNDarray:
        """
        Helper function for the k-means|| initialization of centroids. Samples new centroids based
        on a probability distribution derived from the distance of data points to the current set
        of centroids.

        Parameters
        ----------
        x : DNDarray
            The data to initialize the clusters for. Shape = (n_samples, n_features)
        centroids : DNDarray
            The initial set of centroids
        oversampling : float
            oversampling factor used in the k-means|| initialization of centroids
        num_iters : int
            number of iterations used in the initialization of centroids

        Returns
        -------
        DNDarray
            The passed centroids, followed by all candidates sampled during the iterations.
        """
        # Pre-allocate receive buffer for later communication
        world_size = ht.MPI_WORLD.size
        max_total = world_size * x.shape[0]
        recv_buf = torch.empty(
            (max_total, x.shape[1]), dtype=x.larray.dtype, device=x.larray.device
        )
        # Lower bound for the normalization constant; increases numerical stability for many
        # duplicate data points. Loop-invariant, hence created only once.
        eps = ht.array(1e-12, split=None, device=x.device)

        for _ in range(num_iters):
            # Calculate the distance between data points and the current set of centroids
            distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
            min_distance = distance.min(axis=1)
            min_dist_sum = ht.maximum(min_distance.sum(), eps)
            # Sample each point in the data to a new set of centroids
            # --> probability distribution with oversampling factor (vector)
            prob = oversampling * min_distance / min_dist_sum
            # Define a list of random, uniformly distributed probabilities
            sample = ht.random.rand(x.shape[0], split=x.split)
            # --> choose indices to sample the data according to prob (vector)
            idx = ht.where(sample <= prob)
            # Extract the local candidate centroids on each process
            local_candidates = x[idx].larray
            # ensure correct shape if no candidate is found (required for communication)
            local_candidates = local_candidates.reshape(-1, x.shape[1])
            # number of candidates
            n_local = local_candidates.shape[0]
            # Gather the number of local candidates from each MPI rank into a list
            counts = ht.MPI_WORLD.allgather(n_local)
            # Build a list of starting offsets so each rank's block lands in the correct slice of
            # the buffer
            displs = [0]
            for c in counts[:-1]:
                displs.append(displs[-1] + c)
            # Compute the total number of candidates across all ranks
            total = displs[-1] + counts[-1]
            # Take only the first 'total' rows from the preallocated receive buffer
            buffer = recv_buf[:total]
            # Gather all local candidates into the receive buffer
            ht.MPI_WORLD.Allgatherv(local_candidates, (buffer, counts, displs), recv_axis=0)
            # --> the data points that are identified as possible centroids
            new_candidates = ht.array(buffer, split=None, device=x.device)
            # --> stack these data points to the DNDarray of centroids (tensor)
            centroids = ht.row_stack((centroids, new_candidates))

        return centroids

    def _assign_to_cluster(self, x: DNDarray, eval_functional_value: bool = False) -> DNDarray:
        """
        Assigns the passed data points to the centroids based on the respective metric

        Parameters
        ----------
        x : DNDarray
            Data points, Shape = (n_samples, n_features)
        eval_functional_value : bool, default: False
            If True, the current K-Clustering functional value of the clustering algorithm is
            evaluated and stored in ``self._functional_value``

        Returns
        -------
        DNDarray
            Index of the closest centroid for each data point; the reduced axis is kept
            (``keepdims=True``).
        """
        # calculate the distance matrix and determine the closest centroid
        distances = self._metric(x, self._cluster_centers)
        matching_centroids = distances.argmin(axis=1, keepdims=True)

        if eval_functional_value:
            # p-th power of the p-norm of the distances to the closest centroids
            self._functional_value = ht.norm(distances.min(axis=1), ord=self._p) ** self._p

        return matching_centroids

    def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
        """
        The Update strategy is algorithm specific (e.g. calculate mean of assigned points for
        kmeans, median for kmedians, etc.)

        Parameters
        ----------
        x : DNDarray
            Input Data
        matching_centroids: DNDarray
            Index array of assigned centroids

        Raises
        ------
        NotImplementedError
            Always; has to be overridden by the concrete algorithm.
        """
        raise NotImplementedError()

    def fit(self, x: DNDarray):
        """
        Computes the centroid of the clustering algorithm to fit the data ``x``. The full pipeline
        is algorithm specific.

        Parameters
        ----------
        x : DNDarray
            Training instances to cluster. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always; has to be overridden by the concrete algorithm.
        """
        raise NotImplementedError()

    def predict(self, x: DNDarray) -> DNDarray:
        """
        Predict the closest cluster each sample in ``x`` belongs to.

        In the vector quantization literature, :func:`cluster_centers_` is called the code book
        and each value returned by predict is the index of the closest code in the code book.
        As the assignment is run with ``eval_functional_value=True``, this call also overwrites
        the stored functional value.

        Parameters
        ----------
        x : DNDarray
            New data to predict. Shape = (n_samples, n_features)

        Raises
        ------
        ValueError
            If ``x`` is not a DNDarray.
        """
        # input sanitation
        if not isinstance(x, DNDarray):
            raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

        # determine the centroids
        return self._assign_to_cluster(x, eval_functional_value=True)