"""
Implements polar decomposition (PD)
"""
import numpy as np
import collections
import torch
from typing import Type, Callable, Dict, Any, TypeVar, Union, Tuple
from ..communication import MPICommunication, MPI
from ..dndarray import DNDarray
from .. import factories
from .. import types
from . import matrix_norm, vector_norm, matmul, qr, solve_triangular
from .basics import _estimate_largest_singularvalue, condest
from ..indexing import where
from ..random import randn
from ..devices import Device
from ..manipulations import vstack, hstack, concatenate, diag, balance
from ..exponential import sqrt
from .. import statistics
from scipy.special import ellipj
from scipy.special import ellipkm1
__all__ = ["polar"]
def _zolopd_n_iterations(r: int, kappa: float) -> int:
"""
Returns the number of iterations required in the Zolotarev-PD algorithm.
See the Table 3.1 in: Nakatsukasa, Y., & Freund, R. W. (2016). Computing Fundamental Matrix Decompositions Accurately via the Matrix Sign Function in Two Iterations: The Power of Zolotarev's Functions. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334
Inputs are `r` and `kappa` (named as in the paper), and the output is the number of iterations.
"""
if kappa <= 1e2:
its = [4, 3, 2, 2, 2, 2, 2, 2]
elif kappa <= 1e3:
its = [3, 3, 2, 2, 2, 2, 2, 2]
elif kappa <= 1e5:
its = [5, 3, 3, 3, 2, 2, 2, 2]
elif kappa <= 1e7:
its = [5, 4, 3, 3, 3, 2, 2, 2]
else:
its = [6, 4, 3, 3, 3, 3, 3, 2]
return its[r - 1]
def _compute_zolotarev_coefficients(
    r: int, ell: float, device: str, dtype: types.datatype = types.float64
) -> Tuple[DNDarray, DNDarray, DNDarray]:
    """
    Computes the coefficients of the Zolotarev function used in one Zolotarev-PD iteration.

    These are c=(c_i)_i defined in equation (3.4), as well as a=(a_j)_j and Mhat defined in
    formulas (4.2)/(4.3) of the paper Nakatsukasa, Y., & Freund, R. W. (2016). Computing
    Fundamental Matrix Decompositions Accurately via the Matrix Sign Function in Two Iterations:
    The Power of Zolotarev's Functions. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334.
    The complete elliptic integral of the first kind and the Jacobi elliptic functions are
    evaluated with SciPy (``ellipkm1`` and ``ellipj``).

    Parameters
    ----------
    r : int
        The Zolotarev parameter (named as in the paper).
    ell : float
        The parameter ``ell`` (named as in the paper).
    device : str
        Device that is forwarded to ``factories.array`` for all three outputs.
    dtype : types.datatype, optional
        Heat data type of the outputs (required for reasons of consistency).

    Returns
    -------
    Tuple[DNDarray, DNDarray, DNDarray]
        In this order: ``c`` (``2 * r`` entries), ``a`` (``r`` entries) and ``Mhat``,
        each as a non-split DNDarray.
    """
    # evaluation points i * ellipkm1(ell^2) / (2r + 1) for i = 1, ..., 2r
    eval_points = np.arange(1, 2 * r + 1) * ellipkm1(ell**2) / (2 * r + 1)
    # the first two outputs of ellipj are the Jacobi elliptic functions sn and cn
    sn_cn = np.asarray(ellipj(eval_points, 1 - ell**2)[:2])
    coeffs = ell**2 * sn_cn[0, :] ** 2 / sn_cn[1, :] ** 2

    weights = np.zeros(r)
    scaling = 1
    for j in range(r):
        # entries with even (0-based) index play the role of the "odd" coefficients c_{2j-1} of the paper
        pivot = coeffs[2 * j]
        numerator = 1
        denominator = 1
        for k in range(r):
            numerator *= pivot - coeffs[2 * k + 1]
            if k != j:
                denominator *= pivot - coeffs[2 * k]
        weights[j] = -numerator / denominator
        scaling *= (1 + pivot) / (1 + coeffs[2 * j + 1])

    return tuple(
        factories.array(values, dtype=dtype, split=None, device=device)
        for values in (coeffs, weights, scaling)
    )
def _in_place_qr_with_q_only(A: DNDarray, procs_to_merge: int = 2) -> None:
    r"""
    Replaces the local data of ``A`` by the factor Q of a reduced QR decomposition of ``A``.

    The inputs ``A`` and ``procs_to_merge`` are as in heat.linalg.qr; the difference is that this
    routine modifies ``A`` in place and replaces it with Q, while R is discarded. Nothing is returned.

    NOTE(review): ``procs_to_merge`` is not referenced anywhere in this function body.
    """
    if not A.is_distributed() or A.split < A.ndim - 2:
        # handle the case of a single process or split=None: just PyTorch QR
        # difference to heat.linalg.qr: we only return Q and put it directly in place of A
        A.larray, R = torch.linalg.qr(A.larray, mode="reduced")
        del R
    elif A.split == A.ndim - 1:
        # handle the case that A is split along the columns
        # unlike in heat.linalg.qr, we know by assumption of Zolo-PD that A has at least as many rows as columns
        nprocs = A.comm.size
        with torch.no_grad():
            for i in range(nprocs):
                # this loop goes through all the column-blocks (i.e. local arrays) of the matrix
                # this corresponds to the loop over all columns in classical Gram-Schmidt
                A_lshapes = A.lshape_map
                if i < nprocs - 1:
                    if A.comm.rank > i:
                        # receive buffer for the orthogonalized block of process i
                        Q_buf = torch.zeros(
                            tuple(A_lshapes[i, :]),
                            dtype=A.larray.dtype,
                            device=A.device.torch_device,
                        )
                    # Split is called by every process of A.comm; processes with rank >= i end up
                    # in the sub-communicator with color 1, in which process i has rank 0
                    # NOTE(review): the sub-communicator created in each iteration is not explicitly
                    # released in this function -- confirm whether that is handled elsewhere
                    color = 0 if A.comm.rank < i else 1
                    sub_comm = A.comm.Split(color, A.comm.rank)
                if A.comm.rank == i:
                    # orthogonalize the current block of columns by utilizing PyTorch QR
                    Q, R = torch.linalg.qr(A.larray, mode="reduced")
                    del R
                    A.larray[...] = Q
                    del Q
                    if i < nprocs - 1:
                        # process i sends its (now orthogonalized) local array directly
                        Q_buf = A.larray
                if i < nprocs - 1 and A.comm.rank >= i:
                    # broadcast the orthogonalized block from process i to all processes with higher rank
                    sub_comm.Bcast(Q_buf, root=0)
                if A.comm.rank > i:
                    # subtract the contribution of the current block of columns from the remaining columns
                    R_loc = torch.transpose(Q_buf, -2, -1) @ A.larray
                    A.larray -= Q_buf @ R_loc
                    del R_loc, Q_buf
    else:
        # remaining case: fall back to heat.linalg.qr
        # NOTE(review): this assignment only rebinds the local name ``A``; no in-place write to the
        # array passed by the caller is visible here -- confirm whether this branch is intended
        # to be reachable and whether the result needs to be written back
        A, r = qr(A)
        del r
# public interface
def polar(
    A: DNDarray,
    r: int = None,
    calcH: bool = True,
    condition_estimate: float = 1.0e16,
    silent: bool = True,
    r_max: int = 8,
) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]:
    """
    Computes the so-called polar decomposition of the input 2D DNDarray ``A``, i.e., it returns the matrix ``U``
    with orthonormal columns and the symmetric, positive semi-definite matrix ``H`` such that ``A = U @ H``.

    Parameters
    ----------
    A : ht.DNDarray,
        The input matrix for which the polar decomposition is computed;
        must be two-dimensional, of data type float32 or float64, and must have at least as many rows as columns.
    r : int, optional, default: None
        The parameter r used in the Zolotarev-PD algorithm; if provided, must be an integer between 1 and 8.
        If ``A`` is distributed, r must in addition be a non-trivial divisor of the number of MPI processes,
        i.e., it must divide the number of processes and be different from it.
        Higher values of r lead to faster convergence, but memory consumption is proportional to r.
        If not provided, the largest 1 <= r <= r_max that is a non-trivial divisor of the number of MPI processes is chosen.
    calcH : bool, optional, default: True
        If True, the function returns the matrix H as well. If False, only the matrix U is returned.
    condition_estimate : float, optional, default: 1.e16.
        This argument allows to provide an estimate for the condition number of the input matrix ``A``, if such estimate is already known.
        If a positive number greater than 1., this value is used as an estimate for the condition number of A.
        If smaller or equal than 1., the condition number is estimated internally.
        The default value of 1.e16 is the worst case scenario considered in [1].
    silent : bool, optional, default: True
        If True, the function does not print any output. If False, some information is printed during the computation.
    r_max : int, optional, default: 8
        See the description of r for the meaning; r_max is only taken into account if r is not provided.

    Returns
    -------
    DNDarray or Tuple[DNDarray, DNDarray]
        ``(U, H)`` if ``calcH`` is True, otherwise only ``U``.

    Raises
    ------
    TypeError
        If ``A`` is not a DNDarray, if its data type is neither float32 nor float64, or if ``condition_estimate`` is not a float.
    ValueError
        If ``A`` is not two-dimensional, has fewer rows than columns, or if ``r`` (respectively ``r_max``) is not admissible.

    Notes
    -----
    The implementation follows Algorithm 5.1 in Reference [1]; however, instead of switching from QR to Cholesky decomposition depending on the condition number,
    we stick to QR decomposition in all iterations.
    If ``A`` is not distributed, the decomposition is computed from a process-local SVD instead; all inputs are still validated in this case.

    References
    ----------
    [1] Nakatsukasa, Y., & Freund, R. W. (2016). Computing Fundamental Matrix Decompositions Accurately via the Matrix Sign Function in Two Iterations: The Power of Zolotarev's Functions. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334.
    """
    # check whether input is DNDarray of correct shape
    if not isinstance(A, DNDarray):
        raise TypeError(f"Input ``A`` needs to be a DNDarray but is {type(A)}.")
    if not A.ndim == 2:
        raise ValueError(f"Input ``A`` needs to be a 2D DNDarray, but its dimension is {A.ndim}.")
    if A.shape[0] < A.shape[1]:
        raise ValueError(
            f"Input ``A`` must have at least as many rows as columns, but has shape {A.shape}."
        )

    # check if A is a real floating point matrix and choose tolerances tol accordingly
    if A.dtype == types.float32:
        tol = 1.19e-7
    elif A.dtype == types.float64:
        tol = 2.22e-16
    else:
        raise TypeError(
            f"Input ``A`` must be of data type float32 or float64 but has data type {A.dtype}"
        )

    # check if input for r is reasonable
    if r is not None:
        if not isinstance(r, int) or r < 1 or r > 8:
            raise ValueError(
                f"If specified, input ``r`` must be an integer between 1 and 8, but is {r} of data type {type(r)}."
            )
        if A.is_distributed() and (A.comm.size % r != 0 or A.comm.size == r):
            raise ValueError(
                f"If specified, input ``r`` must be a non-trivial divisor of the number MPI processes , but r={r} and A.comm.size={A.comm.size}."
            )
    else:
        if not isinstance(r_max, int) or r_max < 1 or r_max > 8:
            raise ValueError(
                f"If specified, input ``r_max`` must be an integer between 1 and 8, but is {r_max} of data type {type(r_max)}."
            )
        for i in range(r_max, 0, -1):
            if A.comm.size % i == 0 and A.comm.size // i > 1:
                r = i
                break
        # with a single process no admissible r exists and r stays None; in this case there is
        # nothing meaningful to report (and r is not needed in the non-distributed path below)
        if r is not None and not silent:
            if A.comm.rank == 0:
                print(f"Automatically chosen r={r} (r_max = {r_max}, {A.comm.size} processes).")

    # check if input for condition_estimate is reasonable
    if not isinstance(condition_estimate, float):
        raise TypeError(
            f"If specified, input ``condition_estimate`` must be a float but is {type(condition_estimate)}."
        )

    # early out for the non-distributed case: U and H are obtained from a local SVD A = U_svd diag(s) vh
    if not A.is_distributed():
        U, s, vh = torch.linalg.svd(A.larray, full_matrices=False)
        U @= vh
        if not calcH:
            return factories.array(U, is_split=None, comm=A.comm)
        # H is only assembled if it is actually requested
        H = vh.T @ torch.diag(s) @ vh
        return factories.array(U, is_split=None, comm=A.comm), factories.array(
            H, is_split=None, comm=A.comm
        )

    alpha = _estimate_largest_singularvalue(A).item()

    if condition_estimate <= 1.0:
        kappa = condest(A).item()
    else:
        kappa = condition_estimate

    if A.comm.rank == 0 and not silent:
        print(
            f"Condition number estimate: {kappa:2.2e} / Estimate for largest singular value: {alpha:2.2e}."
        )

    # NOTE(review): the two communicators created below are not explicitly released in this
    # function -- confirm whether that is handled elsewhere
    # each of these communicators has size r, along these communicators we parallelize the r many QR decompositions that are performed in parallel
    horizontal_comm = A.comm.Split(A.comm.rank // r, A.comm.rank)
    # each of these communicators has size MPI_WORLD.size / r and will carry a full copy of X for QR decomposition
    vertical_comm = A.comm.Split(A.comm.rank % r, A.comm.rank)

    # in each horizontal communicator, collect the local array of X from all processes
    local_shapes = horizontal_comm.allgather(A.lshape[A.split])
    new_local_shape = (
        (sum(local_shapes), A.shape[1]) if A.split == 0 else (A.shape[0], sum(local_shapes))
    )
    counts = tuple(local_shapes)
    displacements = tuple(np.cumsum([0] + list(local_shapes))[:-1])
    X_collected_local = torch.zeros(
        new_local_shape, dtype=A.dtype.torch_type(), device=A.device.torch_device
    )
    horizontal_comm.Allgatherv(
        A.larray, (X_collected_local, counts, displacements), recv_axis=A.split
    )

    X = factories.array(X_collected_local, is_split=A.split, comm=vertical_comm)
    X.balance_()
    X /= alpha

    # iteration counter and maximum number of iterations
    it = 0
    itmax = _zolopd_n_iterations(r, kappa)

    # parameters and coefficients, see Ref. [1] for their meaning
    ell = 1.0 / kappa
    c, a, Mhat = _compute_zolotarev_coefficients(r, ell, A.device, dtype=A.dtype)

    while it < itmax:
        it += 1
        if not silent:
            if A.comm.rank == 0:
                print(f"Starting Zolotarev-PD iteration no. {it}...")
        # remember current X for later convergence check
        X_old = X.copy()
        # stack sqrt(c_{2j-1}) * Identity below X; j is determined by the rank in the horizontal communicator
        cId = factories.eye(X.shape[1], dtype=X.dtype, comm=X.comm, split=X.split, device=X.device)
        cId *= c[2 * horizontal_comm.rank].item() ** 0.5
        X = concatenate([X, cId], axis=0)
        del cId
        if X.split == 0:
            Q, R = qr(X)
            del R
            Q1 = Q[: A.shape[0], :].balance()
            Q2 = Q[A.shape[0] :, :].transpose().balance()
            # Q is not needed anymore once its two blocks have been extracted
            del Q
            Q1Q2 = matmul(Q1, Q2)
            del Q1, Q2
            X = X[: A.shape[0], :].balance()
            X /= r
        else:
            _in_place_qr_with_q_only(X)
            Q1 = X[: A.shape[0], :].balance()
            Q2 = X[A.shape[0] :, :].transpose().balance()
            del X
            Q1Q2 = matmul(Q1, Q2)
            del Q1, Q2
            X = X_old / r
        X += a[horizontal_comm.rank].item() / c[2 * horizontal_comm.rank].item() ** 0.5 * Q1Q2
        del Q1Q2
        X *= Mhat.item()
        # finally, sum over the horizontal communicators
        horizontal_comm.Allreduce(MPI.IN_PLACE, X.larray, op=MPI.SUM)

        # check for convergence and break if tolerance is reached
        if it > 1 and matrix_norm(X - X_old, ord="fro") / matrix_norm(X, ord="fro") <= tol ** (
            1 / (2 * r + 1)
        ):
            if not silent:
                if A.comm.rank == 0:
                    print(f"Zolotarev-PD iteration converged after {it} iterations.")
            break
        elif it < itmax:
            # if another iteration is necessary, update coefficients and parameters for next iteration
            ellold = ell
            ell = 1
            for j in range(r):
                ell *= (ellold**2 + c[2 * j + 1].item()) / (ellold**2 + c[2 * j].item())
            ell *= Mhat.item() * ellold
            # keep ell strictly below 1
            if ell >= 1.0:
                ell = 1.0 - tol
            c, a, Mhat = _compute_zolotarev_coefficients(r, ell, A.device, dtype=A.dtype)
        else:
            # last admissible iteration has been performed without meeting the convergence criterion
            if not silent:
                if A.comm.rank == 0:
                    print(
                        f"Zolotarev-PD iteration did not reach the convergence criterion after {itmax} iterations, which is most likely due to limited numerical accuracy and/or poor estimation of the condition number. The result may still be useful, but should be handled with care!"
                    )

    # as every process has much more data than required, we need to split the result:
    # every process of a horizontal communicator keeps one contiguous slice of the local array of X
    local_extent = X.lshape[X.split]
    counts = [
        local_extent // horizontal_comm.size + (idx < local_extent % horizontal_comm.size)
        for idx in range(horizontal_comm.size)
    ]
    displacements = [sum(counts[:idx]) for idx in range(horizontal_comm.size)]
    start = displacements[horizontal_comm.rank]
    stop = start + counts[horizontal_comm.rank]

    if A.split == 1:
        U_local = X.larray[:, start:stop]
    else:
        U_local = X.larray[start:stop, :]

    U = factories.array(U_local, is_split=A.split, comm=A.comm, device=A.device)
    del X
    U.balance_()

    # postprocessing: compute H if requested
    if calcH:
        H = matmul(U.T, A)
        # symmetrize H explicitly
        H = 0.5 * (H + H.T.resplit(H.split))
        return U, H.resplit(A.split)
    else:
        return U