heat.classification.kneighborsclassifier
Implements the k-nearest neighbors (kNN) classifier
Module Contents
- class KNeighborsClassifier(n_neighbors: int = 5, effective_metric_: Callable = None)
Bases: heat.BaseEstimator, heat.ClassificationMixin
Implementation of the k-nearest-neighbors algorithm [1].
This algorithm assigns labels to data vectors by using a labeled training dataset as reference. Each input vector to be classified is compared to every training vector by calculating the distance between them, which is the Euclidean distance by default. The predicted class is the majority vote among the labels of the k nearest training vectors, i.e. those with the smallest distance to the input vector.
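The prediction rule described above can be sketched in plain Python. This is a minimal illustration that assumes list-based inputs and a single query vector; it is not heat's distributed implementation. The names `euclidean` and `knn_predict` are hypothetical helpers introduced only for this example.

```python
import math
from collections import Counter


def euclidean(a, b):
    """Euclidean distance between two equal-length feature vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def knn_predict(train_x, train_y, query, n_neighbors=5, metric=euclidean):
    """Predict one label by majority vote among the nearest training vectors.

    Hypothetical helper for illustration only; it is not part of heat.
    """
    # Distance from the query to every training vector, paired with its label.
    distances = [(metric(row, query), label) for row, label in zip(train_x, train_y)]
    # Keep the n_neighbors closest entries (sorted is stable, so ties keep input order).
    nearest = sorted(distances, key=lambda pair: pair[0])[:n_neighbors]
    # Majority vote over the labels of the selected neighbors.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Two well-separated clusters with integer class labels.
train_x = [[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]]
train_y = [0, 0, 0, 1, 1, 1]

label_near_origin = knn_predict(train_x, train_y, [0.3, 0.3], n_neighbors=3)
label_near_five = knn_predict(train_x, train_y, [4.8, 5.0], n_neighbors=3)
```

Passing a different callable as `metric` changes how "nearest" is measured without touching the voting step, which mirrors the role of the optional distance function in the class signature.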
- Parameters:
n_neighbors (int, optional, default: 5) – Number of neighbors to consider when choosing the label.
effective_metric_ (Callable, optional) – The distance function used to identify the nearest neighbors. Defaults to the Euclidean distance.
References
[1] T. Cover and P. Hart, “Nearest Neighbor Pattern Classification,” in IEEE Transactions on Information Theory, vol. 13, no. 1, pp. 21-27, January 1967, doi: 10.1109/TIT.1967.1053964.
- one_hot_encoding(x: heat.core.dndarray.DNDarray) -> heat.core.dndarray.DNDarray
One-hot-encodes the passed vector or single-column matrix.
- Parameters:
x (DNDarray) – The data to be encoded.
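As an illustration of what one-hot encoding produces, the following plain-Python sketch maps integer class labels to rows containing a single 1. It assumes non-negative integer labels numbered from 0 and uses a hypothetical helper, `one_hot_encode`; it does not reproduce heat's DNDarray-based method.

```python
def one_hot_encode(labels):
    """One-hot-encode a list of non-negative integer class labels.

    Hypothetical helper for illustration only; it is not heat's implementation.
    Assumes the classes are numbered 0..max(labels).
    """
    n_classes = max(labels) + 1
    encoded = []
    for label in labels:
        # One row per sample, with a single 1 in the column of its class.
        row = [0] * n_classes
        row[label] = 1
        encoded.append(row)
    return encoded


encoded = one_hot_encode([0, 2, 1, 2])
```

Each input label becomes one row of length `n_classes`, so a vector of shape (n_samples) turns into a matrix of shape (n_samples, n_classes), matching the two label layouts that `fit` accepts.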
- fit(x: heat.core.dndarray.DNDarray, y: heat.core.dndarray.DNDarray)
Fit the k-nearest neighbors classifier from the training dataset.
- Parameters:
x (DNDarray) – Labeled training vectors used for comparison in predictions. Shape=(n_samples, n_features).
y (DNDarray) – Corresponding labels for the training feature vectors. Must have the same number of samples as x. Shape=(n_samples) if integral labels or Shape=(n_samples, n_classes) if one-hot-encoded.
- Raises:
TypeError – If x or y are not DNDarrays.
ValueError – If x and y shapes mismatch or are not two-dimensional matrices.
Examples
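>>> import heat as ht
>>> from heat.classification.kneighborsclassifier import KNeighborsClassifier
>>> samples = ht.rand(10, 3)
>>> labels = ht.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])  # illustrative integer labels, one per sample
>>> knn = KNeighborsClassifier(n_neighbors=1)
>>> knn.fit(samples, labels)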
- predict(x: heat.core.dndarray.DNDarray) -> heat.core.dndarray.DNDarray
Predict the class labels for the provided data.
- Parameters:
x (DNDarray) – The test samples.