heat.regression.lasso
Implementation of LASSO regression.
Module Contents
- class Lasso(lam: float | None = 0.1, max_iter: int | None = 100, tol: float | None = 1e-06)
Bases: heat.RegressionMixin, heat.BaseEstimator
Least absolute shrinkage and selection operator (LASSO), a linear model with L1 regularization. The optimization objective for Lasso is:
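\[E(w) = \frac{1}{2m} \|y - Xw\|_2^2 + \lambda \|w\_\|_1\]
with
\[w\_ = (w_1, w_2, \dots, w_n), \quad w = (w_0, w_1, w_2, \dots, w_n),\]
\[y \in M(m \times 1), \quad w \in M((n+1) \times 1), \quad X \in M(m \times (n+1)),\]
where the first column of X is all ones, so that w_0 is the intercept and is excluded from the L1 penalty.
- Parameters: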
lam (float, optional) – Constant λ that multiplies the L1 term. Default value: 0.1. Setting lam = 0. is equivalent to ordinary least squares (OLS); for numerical reasons, using lam = 0. with the Lasso object is not advised.
max_iter (int, optional) – The maximum number of iterations. Default value: 100
tol (float, optional) – The tolerance for the optimization. Default value: 1e-06
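To make the objective and the role of lam concrete, here is a minimal NumPy sketch of E(w) (not heat's code; the function name lasso_objective is hypothetical). It assumes x already carries the leading column of ones, so w[0] is the intercept and is left out of the L1 term. With lam = 0 only the least-squares term remains, which is the OLS case mentioned for lam.

```python
import numpy as np


def lasso_objective(x, y, w, lam):
    """E(w) = 1/(2m) * ||y - Xw||_2^2 + lam * ||w_||_1.

    x has shape (m, n+1) with a leading column of ones and w = (w_0, ..., w_n);
    the intercept w_0 is excluded from the L1 penalty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float).ravel()
    w = np.asarray(w, dtype=float).ravel()
    m = x.shape[0]
    residual = y - x @ w
    # Half mean squared error plus the L1 norm of the non-intercept weights
    return residual @ residual / (2 * m) + lam * np.abs(w[1:]).sum()
```

For a perfect fit the squared-error term vanishes and only the penalty lam * (|w_1| + ... + |w_n|) is left.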
- Variables:
__theta (array, shape (n_features + 1,)) – Parameter vector w; the first element is the intercept.
coef (array, shape (n_features,) | (n_targets, n_features)) – Parameter vector without the intercept (w_1, ..., w_n in the cost function formula).
intercept (float | array, shape (n_targets,)) – Independent term (intercept) in the decision function.
n_iter (int or None | array-like, shape (n_targets,)) – Number of iterations run by the coordinate descent solver to reach the specified tolerance.
Examples
>>> import heat as ht
>>> X = ht.random.randn(10, 4, split=0)
>>> y = ht.random.randn(10, 1, split=0)
>>> estimator = ht.regression.lasso.Lasso(max_iter=100, tol=None)
>>> estimator.fit(X, y)
- soft_threshold(rho: heat.core.dndarray.DNDarray) → heat.core.dndarray.DNDarray | float
Soft threshold operator
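The standard soft-thresholding operator is S_lam(rho) = sign(rho) * max(|rho| - lam, 0): values inside [-lam, lam] become exactly zero, which is what lets Lasso drive coefficients to zero. The heat method takes only rho; the standalone sketch below passes the threshold explicitly as lam, an assumption made for illustration.

```python
def soft_threshold(rho: float, lam: float) -> float:
    """Soft-thresholding: shrink rho toward zero by lam, clip to 0 inside [-lam, lam]."""
    if rho > lam:
        return rho - lam
    if rho < -lam:
        return rho + lam
    return 0.0
```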
- rmse(gt: heat.core.dndarray.DNDarray, yest: heat.core.dndarray.DNDarray) → heat.core.dndarray.DNDarray
Root mean square error (RMSE)
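The quantity computed is sqrt(mean((gt_i - yest_i)^2)). A plain-Python sketch for ordinary sequences (instead of DNDarrays) looks like this:

```python
import math
from typing import Sequence


def rmse(gt: Sequence[float], yest: Sequence[float]) -> float:
    """Root mean square error between ground truth gt and estimate yest."""
    if len(gt) != len(yest) or len(gt) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return math.sqrt(sum((g - e) ** 2 for g, e in zip(gt, yest)) / len(gt))
```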
- fit(x: heat.core.dndarray.DNDarray, y: heat.core.dndarray.DNDarray) → None
Fit the Lasso model using coordinate descent.
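Coordinate descent minimizes E(w) one weight at a time. For coordinate j, with partial residual r_j = y - Xw + X_j w_j, rho_j = (1/m) X_j^T r_j and z_j = (1/m) X_j^T X_j, the exact minimizer is w_j = S_lam(rho_j) / z_j, and w_0 = rho_0 / z_0 because the intercept is unpenalized. The NumPy sketch below follows that textbook update; it is not heat's distributed implementation, the function name is hypothetical, and the stopping rule (RMSE between successive iterates below tol) is an assumption.

```python
import numpy as np


def lasso_coordinate_descent(x, y, lam=0.1, max_iter=100, tol=1e-6):
    """Minimize E(w) = 1/(2m) * ||y - Xw||_2^2 + lam * ||w[1:]||_1 by cyclic
    coordinate descent. x must contain a leading column of ones, so w[0]
    is the unpenalized intercept. Returns (w, n_iter)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float).ravel()
    m, n = x.shape
    w = np.zeros(n)
    # z_j = (1/m) * X_j^T X_j; assumes no all-zero columns
    z = (x ** 2).sum(axis=0) / m
    for it in range(1, max_iter + 1):
        w_old = w.copy()
        for j in range(n):
            # Partial residual: y minus the fit of every coordinate except j
            r_j = y - x @ w + x[:, j] * w[j]
            rho = x[:, j] @ r_j / m
            if j == 0:
                w[j] = rho / z[j]  # intercept: plain least-squares update
            else:
                # Soft-thresholding: shrink rho by lam, clip to 0 inside [-lam, lam]
                w[j] = np.sign(rho) * max(abs(rho) - lam, 0.0) / z[j]
        # Assumed stopping rule: RMSE between successive iterates below tol
        if tol is not None and np.sqrt(np.mean((w - w_old) ** 2)) < tol:
            return w, it
    return w, max_iter
```

With lam = 0 the update reduces to Gauss-Seidel on the least-squares problem; with a large lam every slope is thresholded to zero and only the intercept (the mean of y) survives.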
- predict(x: heat.core.dndarray.DNDarray) → heat.core.dndarray.DNDarray
Apply the Lasso model to input data. The first column of x corresponds to the intercept.
- Parameters:
x (DNDarray) – Input data, Shape = (n_samples, n_features)
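Prediction is the linear map x @ w, where column 0 of x multiplies the intercept w_0. The NumPy sketch below assumes the caller builds that column with a hypothetical helper, add_intercept_column; neither function is part of heat's API.

```python
import numpy as np


def add_intercept_column(features):
    """Prepend a column of ones so that column 0 multiplies the intercept."""
    features = np.asarray(features, dtype=float)
    return np.hstack([np.ones((features.shape[0], 1)), features])


def predict(x, theta):
    """Linear prediction x @ theta, with theta = (intercept, w_1, ..., w_n)."""
    return np.asarray(x, dtype=float) @ np.asarray(theta, dtype=float).ravel()
```

For example, with intercept 1 and slope 2, the inputs 0, 1, 2 map to 1, 3, 5.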