Source code for heat.core.base

"""Provides mixins for high-level algorithms, e.g. classifiers or clustering algorithms."""

import inspect
import json

from typing import Dict, List, TypeVar

from .dndarray import DNDarray

# Type variable used in return annotations of methods that hand back the
# estimator instance they were called on (see ``BaseEstimator.set_params``).
self = TypeVar("self")


class BaseEstimator:
    """
    Abstract base class for all estimators, i.e. parametrized analysis algorithms, in Heat. Can be used as mixin.

    The parameters of an estimator are the ``POSITIONAL_OR_KEYWORD`` arguments of its constructor. Each of them is
    expected to be available as an attribute of the same name on the instance.
    """

    @classmethod
    def _parameter_names(cls) -> List[str]:
        """
        Get the names of all parameters that can be set inside the constructor of the estimator.

        Returns
        -------
        names : list[str]
            Names of the constructor arguments, excluding ``self`` and excluding anything that is not a
            positional-or-keyword argument (``*args``, ``**kwargs``, keyword-only and positional-only arguments).
            Empty if the class does not define its own constructor.
        """
        init = cls.__init__
        if init is object.__init__:
            # no user-defined constructor, hence no parameters to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        init_signature = inspect.signature(init)

        # consider the constructor parameters excluding 'self'
        return [
            p.name
            for p in init_signature.parameters.values()
            if p.name != "self" and p.kind == p.POSITIONAL_OR_KEYWORD
        ]

    def get_params(self, deep: bool = True) -> Dict[str, object]:
        """
        Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default: True
            If ``True``, will return the parameters for this estimator and contained sub-objects that are estimators.

        Returns
        -------
        params : dict[str, object]
            Mapping of parameter name to its current value. With ``deep=True``, values that expose a ``get_params``
            method are replaced by the dictionary that method returns.

        Raises
        ------
        AttributeError
            If a constructor parameter is not stored as an attribute of the same name.
        """
        params = {}

        for key in self._parameter_names():
            value = getattr(self, key)
            # classes (as opposed to instances) also expose ``get_params`` as a plain function that cannot be
            # called without an instance, so they are stored as they are
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                value = value.get_params()
            params[key] = value

        return params

    def __repr__(self, indent: int = 1) -> str:
        """
        Returns a printable representation of the object.

        Parameters
        ----------
        indent : int, default: 1
            Indicates the indentation for the top-level output. Accepted for interface compatibility; the output is
            currently always indented by four spaces, regardless of this value.
        """
        # fall back to ``repr`` for values JSON cannot encode, so that printing an estimator (which also happens
        # inside the error message of ``set_params``) does not raise a ``TypeError``
        return f"{self.__class__.__name__}({json.dumps(self.get_params(), indent=4, default=repr)})"

    def set_params(self, **params: Dict[str, object]) -> self:
        """
        Set the parameters of this estimator. The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have to be nested dictionaries.

        Parameters
        ----------
        **params : dict[str, object]
            Estimator parameters to be set. A dictionary value is forwarded as keyword arguments to the
            ``set_params`` method of the object currently stored under that parameter name, if it has one;
            otherwise the value is assigned as-is.

        Returns
        -------
        self
            The estimator instance itself.

        Raises
        ------
        ValueError
            If a key is not a valid parameter of this estimator. No parameter is modified in that case.
        """
        if not params:
            return self

        parameter_names = self._parameter_names()

        # validate everything up front so that an invalid key does not leave the estimator half-updated
        for key in params:
            if key not in parameter_names:
                raise ValueError(
                    f"Invalid parameter {key} for estimator {self}. Check the list of available parameters with `estimator.get_params().keys()`."
                )

        for key, value in params.items():
            current = getattr(self, key, None)
            if isinstance(value, dict) and hasattr(current, "set_params"):
                # nested estimator: ``set_params`` only accepts keyword arguments, so unpack the dictionary
                current.set_params(**value)
            else:
                setattr(self, key, value)

        return self
class ClassificationMixin:
    """
    Mixin for all classifiers in Heat.

    Concrete classifiers are expected to override :func:`fit` and :func:`predict`; both raise
    ``NotImplementedError`` here.
    """

    def fit(self, x: DNDarray, y: DNDarray):
        """
        Fits the classification model.

        Parameters
        ----------
        x : DNDarray
            Training instances to train on. Shape = (n_samples, n_features)
        y : DNDarray
            Class values to fit. Shape = (n_samples, )

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete classifier.
        """
        raise NotImplementedError()

    def predict(self, x: DNDarray) -> DNDarray:
        """
        Predicts the class labels for each sample.

        Parameters
        ----------
        x : DNDarray
            Values to predict the classes for. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete classifier.
        """
        raise NotImplementedError()

    def fit_predict(self, x: DNDarray, y: DNDarray) -> DNDarray:
        """
        Fits model and returns classes for each input sample.

        Convenience method; equivalent to calling :func:`fit` followed by :func:`predict`.

        Parameters
        ----------
        x : DNDarray
            Input data to be predicted. Shape = (n_samples, n_features)
        y : DNDarray
            Class values to fit. Shape = (n_samples, )
        """
        self.fit(x, y)
        predicted = self.predict(x)
        return predicted
class TransformMixin:
    """
    Mixin for all transformations in Heat.

    Concrete transformers are expected to override :func:`fit` and :func:`transform`; both raise
    ``NotImplementedError`` here.
    """

    def fit(self, x: DNDarray):
        """
        Fits the transformation model.

        Parameters
        ----------
        x : DNDarray
            Training instances to train on. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete transformer.
        """
        raise NotImplementedError()

    def transform(self, x: DNDarray) -> DNDarray:
        """
        Transforms the input data.

        Parameters
        ----------
        x : DNDarray
            Values to transform. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete transformer.
        """
        raise NotImplementedError()

    def fit_transform(self, x: DNDarray) -> DNDarray:
        """
        Fits model and returns transformed data for each input sample.

        Convenience method; equivalent to calling :func:`fit` followed by :func:`transform`.

        Parameters
        ----------
        x : DNDarray
            Input data to be transformed. Shape = (n_samples, n_features)
        """
        self.fit(x)
        transformed = self.transform(x)
        return transformed
class ClusteringMixin:
    """
    Clustering mixin for all clusterers in Heat.

    Only :func:`fit` is stubbed out here. :func:`fit_predict` additionally calls ``self.predict``, which this mixin
    does not define and which therefore has to be provided by the concrete clusterer.
    """

    def fit(self, x: DNDarray):
        """
        Computes the clustering.

        Parameters
        ----------
        x : DNDarray
            Training instances to cluster. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete clusterer.
        """
        raise NotImplementedError()

    def fit_predict(self, x: DNDarray) -> DNDarray:
        """
        Compute clusters and returns the predicted cluster assignment for each sample.

        Returns index of the cluster each sample belongs to.
        Convenience method; equivalent to calling :func:`fit` followed by :func:`predict`.

        Parameters
        ----------
        x : DNDarray
            Input data to be clustered. Shape = (n_samples, n_features)
        """
        self.fit(x)
        assignment = self.predict(x)
        return assignment
class RegressionMixin:
    """
    Mixin for all regression estimators in Heat.

    Concrete regressors are expected to override :func:`fit` and :func:`predict`; both raise
    ``NotImplementedError`` here.
    """

    def fit(self, x: DNDarray, y: DNDarray):
        """
        Fits the regression model.

        Parameters
        ----------
        x : DNDarray
            Training instances to train on. Shape = (n_samples, n_features)
        y : DNDarray
            Continuous values to fit. Shape = (n_samples,)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete regressor.
        """
        raise NotImplementedError()

    def predict(self, x: DNDarray) -> DNDarray:
        """
        Predicts the continuous labels for each sample.

        Parameters
        ----------
        x : DNDarray
            Values to let the model predict. Shape = (n_samples, n_features)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the concrete regressor.
        """
        raise NotImplementedError()

    def fit_predict(self, x: DNDarray, y: DNDarray) -> DNDarray:
        """
        Fits model and returns regression predictions for each input sample.

        Convenience method; equivalent to calling :func:`fit` followed by :func:`predict`.

        Parameters
        ----------
        x : DNDarray
            Input data to be predicted. Shape = (n_samples, n_features)
        y : DNDarray
            Continuous values to fit. Shape = (n_samples,)
        """
        self.fit(x, y)
        predicted = self.predict(x)
        return predicted
def is_classifier(estimator: object) -> bool:
    """
    Return ``True`` if the given estimator is a classifier, ``False`` otherwise.

    An object counts as a classifier if it is an instance of :class:`ClassificationMixin`.

    Parameters
    ----------
    estimator : object
        Estimator object to test.
    """
    return isinstance(estimator, ClassificationMixin)
def is_transformer(estimator: object) -> bool:
    """
    Return ``True`` if the given estimator is a transformer, ``False`` otherwise.

    An object counts as a transformer if it is an instance of :class:`TransformMixin`.

    Parameters
    ----------
    estimator : object
        Estimator object to test.
    """
    return isinstance(estimator, TransformMixin)
def is_estimator(estimator: object) -> bool:
    """
    Return ``True`` if the given estimator is an estimator, ``False`` otherwise.

    An object counts as an estimator if it is an instance of :class:`BaseEstimator`.

    Parameters
    ----------
    estimator : object
        Estimator object to test.
    """
    return isinstance(estimator, BaseEstimator)
def is_clusterer(estimator: object) -> bool:
    """
    Return ``True`` if the given estimator is a clusterer, ``False`` otherwise.

    An object counts as a clusterer if it is an instance of :class:`ClusteringMixin`.

    Parameters
    ----------
    estimator : object
        Estimator object to test.
    """
    return isinstance(estimator, ClusteringMixin)
def is_regressor(estimator: object) -> bool:
    """
    Return ``True`` if the given estimator is a regressor, ``False`` otherwise.

    An object counts as a regressor if it is an instance of :class:`RegressionMixin`.

    Parameters
    ----------
    estimator : object
        Estimator object to test.
    """
    return isinstance(estimator, RegressionMixin)