"""
Collection of solvers for systems of linear equations.
"""
import heat as ht
from ..dndarray import DNDarray
from ..sanitation import sanitize_out
from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
from .. import factories
import torch
__all__ = ["cg", "lanczos", "solve", "solve_triangular"]
[docs]
def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
    """
    Conjugate gradients method for solving a system of linear equations :math:`Ax = b`.

    At most ``len(b)`` iterations are performed. The iteration stops early as soon as the
    Euclidean norm of the residual drops below ``1e-10``.

    Parameters
    ----------
    A : DNDarray
        2D symmetric, positive definite Matrix
    b : DNDarray
        1D vector
    x0 : DNDarray
        Arbitrary 1D starting vector
    out : DNDarray, optional
        Output Vector. If provided, the solution is copied into this buffer and the buffer
        is returned.

    Returns
    -------
    DNDarray
        The approximate solution ``x``. This is ``out`` if ``out`` was provided. If ``x0``
        already satisfies the tolerance and no ``out`` was given, ``x0`` itself is returned.

    Raises
    ------
    TypeError
        If ``A``, ``b`` or ``x0`` is not a ``DNDarray``.
    RuntimeError
        If ``A`` is not 2D, or ``b`` or ``x0`` is not 1D.
    """
    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray) or not isinstance(x0, DNDarray):
        raise TypeError(
            f"A, b and x0 need to be of type ht.DNDarray, but were {type(A)}, {type(b)}, {type(x0)}"
        )

    if A.ndim != 2:
        raise RuntimeError("A needs to be a 2D matrix")
    if b.ndim != 1:
        raise RuntimeError("b needs to be a 1D vector")
    if x0.ndim != 1:
        raise RuntimeError("x0 needs to be a 1D vector")

    tolerance = 1e-10

    r = b - ht.matmul(A, x0)
    p = r
    rsold = ht.matmul(r, r)
    x = x0

    # If x0 already solves the system, the first step would compute alpha = 0 / 0 and fill
    # the iterate with NaN. Skip the iteration entirely in that case.
    if ht.sqrt(rsold).item() >= tolerance:
        for i in range(len(b)):
            Ap = ht.matmul(A, p)
            alpha = rsold / ht.matmul(p, Ap)
            x = x + alpha * p
            r = r - alpha * Ap
            rsnew = ht.matmul(r, r)
            if ht.sqrt(rsnew).item() < tolerance:
                print(f"Residual reaches tolerance in it = {i}")
                break
            p = r + ((rsnew / rsold) * p)
            rsold = rsnew

    if out is not None:
        # Rebinding the local name `out` would leave the caller's buffer untouched, so
        # validate the buffer and copy the local data into it instead.
        sanitize_out(
            out,
            output_shape=x.shape,
            output_split=x.split,
            output_device=x.device,
            output_comm=x.comm,
        )
        out.larray[:] = x.larray
        return out
    return x
[docs]
def lanczos(
    A: DNDarray,
    m: int,
    v0: Optional[DNDarray] = None,
    V_out: Optional[DNDarray] = None,
    T_out: Optional[DNDarray] = None,
) -> Tuple[DNDarray, DNDarray]:
    r"""
    The Lanczos algorithm is an iterative approximation of the solution to the eigenvalue problem, as an adaptation of
    power methods to find the m "most useful" (tending towards extreme highest/lowest) eigenvalues and eigenvectors of
    an :math:`n \times n` Hermitian matrix, where often :math:`m<<n`.
    It returns two matrices :math:`V` and :math:`T`, where:

        - :math:`V` is a Matrix of size :math:`n\times m`, with orthonormal columns, that span the Krylov subspace \n
        - :math:`T` is a Tridiagonal matrix of size :math:`m\times m`, with coefficients :math:`\alpha_1,..., \alpha_m`
          on the diagonal and coefficients :math:`\beta_1,...,\beta_{m-1}` on the side-diagonals\n

    Parameters
    ----------
    A : DNDarray
        2D Hermitian (if complex) or symmetric positive-definite matrix.
        Only distribution along axis 0 is supported, i.e. `A.split` must be `0` or `None`.
    m : int
        Number of Lanczos iterations
    v0 : DNDarray, optional
        1D starting vector of Euclidean norm 1. If not provided, a random vector will be used to start the algorithm.
        If its split differs from the split used for :math:`V`, it is resplit in place.
    V_out : DNDarray, optional
        Output Matrix for the Krylov vectors, Shape = (n, m), dtype=A.dtype, must be initialized to zero
    T_out : DNDarray, optional
        Output Matrix for the Tridiagonal matrix, Shape = (m, m), must be initialized to zero

    Returns
    -------
    Tuple[DNDarray, DNDarray]
        The pair ``(V, T)``. If :math:`V` is split, it is resplit in place to ``split=None`` before being returned.

    Raises
    ------
    TypeError
        If ``A`` is not a ``DNDarray``, if ``A.dtype`` is ``ht.int32`` or ``ht.int64``, if ``m`` is neither ``int``
        nor ``float``, or if ``A`` is not square.
    RuntimeError
        If ``A`` is not 2D.
    NotImplementedError
        If ``A.split`` is 1.
    """
    if not isinstance(A, DNDarray):
        raise TypeError(f"A needs to be of type ht.dndarray, but was {type(A)}")
    if A.ndim != 2:
        raise RuntimeError("A needs to be a 2D matrix")
    if A.dtype is ht.int32 or A.dtype is ht.int64:
        raise TypeError(f"A can be float or complex, got {A.dtype}")
    # NOTE(review): floats pass this check although the message asks for an int; the loops
    # below use int(m), while (m, m) is handed unchanged to the buffer setup.
    if not isinstance(m, (int, float)):
        raise TypeError(f"m must be int, got {type(m)}")

    n, column = A.shape
    if n != column:
        raise TypeError("Input Matrix A needs to be symmetric positive-definite.")

    # output data types: T is always Real
    A_is_complex = A.dtype is ht.complex128 or A.dtype is ht.complex64
    T_dtype = A.real.dtype

    # initialize or sanitize output buffers
    if T_out is not None:
        sanitize_out(
            T_out,
            output_shape=(m, m),
            output_split=None,
            output_device=A.device,
            output_comm=A.comm,
        )
        T = T_out
    else:
        T = ht.zeros((m, m), dtype=T_dtype, device=A.device, comm=A.comm)
    if A.split == 0:
        # V follows A's row distribution
        if V_out is not None:
            sanitize_out(
                V_out,
                output_shape=(n, m),
                output_split=0,
                output_device=A.device,
                output_comm=A.comm,
            )
            V = V_out
        else:
            # This is done for better memory access in the reorthogonalization Gram-Schmidt algorithm
            V = ht.zeros((n, m), split=0, dtype=A.dtype, device=A.device, comm=A.comm)
    else:
        if A.split == 1:
            raise NotImplementedError("Distribution along axis 1 not implemented yet.")
        if V_out is not None:
            sanitize_out(
                V_out,
                output_shape=(n, m),
                output_split=None,
                output_device=A.device,
                output_comm=A.comm,
            )
            V = V_out
        else:
            V = ht.zeros((n, m), split=None, dtype=A.dtype, device=A.device, comm=A.comm)

    if A_is_complex:
        if v0 is None:
            # random complex start vector, built from two real random vectors and normalized
            vr = (
                ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm)
                + ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm) * 1j
            )
            v0 = vr / ht.norm(vr)
        elif v0.split != V.split:
            v0.resplit_(axis=V.split)
        # 0th iteration
        # vector v0 has Euclidean norm = 1
        w = ht.matmul(A, v0)
        alpha = ht.dot(ht.conj(w).T, v0)
        w = w - alpha * v0
        T[0, 0] = alpha.real
        V[:, 0] = v0
        for i in range(1, int(m)):
            beta = ht.norm(w)
            if ht.abs(beta) < 1e-10:
                # print("Lanczos breakdown in iteration {}".format(i))
                # Lanczos Breakdown, pick a random vector to continue
                vr = (
                    ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm)
                    + ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm)
                    * 1j
                )
                # orthogonalize v_r with respect to all vectors v[i]
                for j in range(i):
                    vi_loc = V._DNDarray__array[:, j]
                    # local partial inner products, summed over all processes below
                    a = torch.dot(vr.larray, torch.conj(vi_loc))
                    b = torch.dot(vi_loc, torch.conj(vi_loc))
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                    vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc
                # normalize v_r to Euclidean norm 1 and set as ith vector v
                vi = vr / ht.norm(vr)
            else:
                vr = w

                # Reorthogonalization
                for j in range(i):
                    vi_loc = V.larray[:, j]
                    # local partial inner products, summed over all processes below
                    a = torch.dot(vr._DNDarray__array, torch.conj(vi_loc))
                    b = torch.dot(vi_loc, torch.conj(vi_loc))
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                    vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc

                vi = vr / ht.norm(vr)

            w = ht.matmul(A, vi)
            alpha = ht.dot(ht.conj(w).T, vi)

            w = w - alpha * vi - beta * V[:, i - 1]

            # only the real parts are stored, T is real-valued
            T[i - 1, i] = beta.real
            T[i, i - 1] = beta.real
            T[i, i] = alpha.real
            V[:, i] = vi
    else:
        if v0 is None:
            # random real start vector, normalized
            vr = ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm)
            v0 = vr / ht.norm(vr)
        elif v0.split != V.split:
            v0.resplit_(axis=V.split)
        # 0th iteration
        # vector v0 has Euclidean norm = 1
        w = ht.matmul(A, v0)
        alpha = ht.dot(w, v0)
        w = w - alpha * v0
        T[0, 0] = alpha
        V[:, 0] = v0
        for i in range(1, int(m)):
            beta = ht.norm(w)
            if ht.abs(beta) < 1e-10:
                # print("Lanczos breakdown in iteration {}".format(i))
                # Lanczos Breakdown, pick a random vector to continue
                vr = ht.random.rand(n, split=V.split, dtype=T_dtype, device=V.device, comm=V.comm)
                # orthogonalize v_r with respect to all vectors v[i]
                for j in range(i):
                    vi_loc = V._DNDarray__array[:, j]
                    # local partial inner products, summed over all processes below
                    a = torch.dot(vr.larray, vi_loc)
                    b = torch.dot(vi_loc, vi_loc)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                    vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc
                # normalize v_r to Euclidean norm 1 and set as ith vector v
                vi = vr / ht.norm(vr)
            else:
                vr = w

                # Reorthogonalization
                for j in range(i):
                    vi_loc = V.larray[:, j]
                    # local partial inner products, summed over all processes below
                    a = torch.dot(vr._DNDarray__array, vi_loc)
                    b = torch.dot(vi_loc, vi_loc)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
                    A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
                    vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc

                vi = vr / ht.norm(vr)

            w = ht.matmul(A, vi)
            alpha = ht.dot(w, vi)

            w = w - alpha * vi - beta * V[:, i - 1]

            T[i - 1, i] = beta
            T[i, i - 1] = beta
            T[i, i] = alpha
            V[:, i] = vi

    # V is gathered (split=None) before it is handed back
    if V.split is not None:
        V.resplit_(axis=None)

    return V, T
[docs]
def solve(A: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
    r"""
    Computes the solution of a square system of linear equations with a unique solution.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
    this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to
    :math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as

    .. math:: AX = B

    Supports inputs of integer, float, double, cfloat and cdouble dtypes.
    Also supports batches of matrices, and if the inputs are batches of matrices then
    the output has the same batch dimensions.

    Letting `*` be zero or more batch dimensions,

    - If :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(*, n)` (a batch of vectors) or shape
      `(*, n, k)` (a batch of matrices or "multiple right-hand sides"), this function returns `X` of shape
      `(*, n)` or `(*, n, k)` respectively.
    - Otherwise, if :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(n,)` or `(n, k)`, :attr:`B`
      is broadcast to have shape `(*, n)` or `(*, n, k)` respectively.
      This function then returns the solution of the resulting batch of systems of linear equations.

    :attr:`B` is interpreted as a (batch of) vector(s) only if it is 1D or if its leading dimensions equal the
    batch dimensions of :attr:`A`; in every other case it is interpreted as a (batch of) matrices. In
    particular, right-hand sides with `k == n` are treated as matrices.

    .. note::
        A and b may only be distributed in the batch dimensions. If both are split, they must be split along matching batch axes.

    .. seealso::
        :func:`torch.linalg.solve` is called under the hood on the local data. This docstring is also heavily inspired by the docstring of this function.

    Parameters
    ----------
    A : DNDarray
        Matrix to be inverted of shape `(*, n, n)` where `*` is zero or more batch dimensions
    b : DNDarray
        Right-hand side of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)`
    out : DNDarray, optional
        Output Vector

    Raises
    ------
    ValueError
        If the shape of ``b`` does not fit ``A``, or if ``A`` and ``b`` are both distributed along different axes.
    NotImplementedError
        If ``b`` is distributed along its non-batched axis, or ``A`` along one of its last two axes.
    RuntimeError
        If the local solve raised a linear algebra error on any process; the collected messages are the arguments.

    Examples::

        >>> A = ht.random.randn(3, 3)
        >>> b = ht.random.randn(3)
        >>> x = ht.linalg.solve(A, b)
        >>> ht.allclose(A @ x, b, atol=1e-5)
        True
        >>> A = ht.random.randn(2, 3, 3, split=0)
        >>> B = ht.random.randn(2, 3, 4, split=0)
        >>> X = ht.linalg.solve(A, B)
        >>> X.shape
        (2, 3, 4)
        >>> ht.allclose(A @ X, B, atol=1e-5)
        True
        >>> A = ht.random.randn(2, 3, 3, split=None)
        >>> B = ht.random.randn(2, 3, 4, split=2)
        >>> X = ht.linalg.solve(A, B)
        >>> X.split
        2
        >>> ht.allclose(A @ X, B, atol=1e-5)
        True
        >>> A = ht.random.randn(2, 3, 3, split=0)
        >>> b = ht.random.randn(3, 1)
        >>> x = ht.linalg.solve(A, b)  # b is broadcast to size (2, 3, 1)
        >>> x.shape
        (2, 3, 1)
        >>> x.split
        0
        >>> ht.allclose((A @ x).resplit_(None), b, atol=1e-5)
        True
    """
    ht.sanitize_in(A)
    ht.sanitize_in(b)

    # torch doesn't support integer, so we cast to float here if needed
    if ht.issubdtype(A.dtype, ht.integer):
        A = A.astype(ht.promote_types(A.dtype, ht.float), copy=True)
    if ht.issubdtype(b.dtype, ht.integer):
        b = b.astype(ht.promote_types(b.dtype, ht.float), copy=True)

    n = A.shape[-1]
    batch_shape = A.shape[:-2]

    # Figure out which is the non-batched axis in b.
    # A trailing dimension of size n alone is ambiguous when k == n, so the vector
    # interpretation is only chosen when it is consistent with A's batch shape; otherwise
    # fall through to the matrix interpretation instead of rejecting a valid input.
    b_is_vector = b.shape[-1] == n and (b.ndim == 1 or b.shape[:-1] == batch_shape)
    if b_is_vector:
        b_non_batched_axis = b.ndim - 1
    elif b.ndim >= 2 and b.shape[-2] == n:
        b_non_batched_axis = b.ndim - 2
    else:
        raise ValueError(f"b has incorrect shape of {b.shape} for A of shape {A.shape}")

    # raise error if b is distributed in disallowed way
    if b.is_distributed() and b.split == b_non_batched_axis:
        raise NotImplementedError(
            f"b of shape {b.shape} with A of shape {A.shape} is split in {b.split} which is the non-batched axis, but it can only be distributed in batched axes in this implementation. If you require this feature, please open an issue on GitHub."
        )

    # raise errors if A is distributed in disallowed way
    if A.is_distributed():
        if A.split > A.ndim - 3:
            raise NotImplementedError(
                f"A of dimension {A.ndim} is split in {A.split} but must not be distributed in the (non-batched) last two axes in this implementation. If you require this feature, please open an issue on GitHub."
            )
        elif A.split != b.split and b.is_distributed():
            raise ValueError(f"Split of A and b must match, but got {A.split} and {b.split}")

    # figure out what the output vector looks like
    out_spec = {"dtype": ht.types.promote_types(A.dtype, b.dtype), "device": b.device}
    if b.shape[:b_non_batched_axis] == batch_shape:  # no need for expansion
        out_spec["shape"] = b.shape
        out_spec["split"] = b.split
        out_spec["comm"] = b.comm
    elif b_non_batched_axis == 0 and A.ndim > 2:  # b needs expanding
        out_spec["shape"] = batch_shape + b.shape
        if A.split is None:
            # The batch axes of A are prepended to b's axes, so b's split axis moves back
            # by the number of batch dimensions in the output.
            out_spec["split"] = None if b.split is None else b.split + len(batch_shape)
            out_spec["comm"] = b.comm
        else:
            out_spec["split"] = A.split
            out_spec["comm"] = A.comm
    else:
        raise ValueError(
            f"Don't know how to batch solve with A of shape {A.shape} and split {A.split} and b of shape {b.shape} and split {b.split}"
        )

    # set up output vector
    if out is None:
        out = factories.empty(**out_spec)
    ht.sanitize_out(
        out,
        output_shape=out_spec["shape"],
        output_split=out_spec["split"],
        output_device=out_spec["device"],
        output_comm=out_spec["comm"],
    )
    # if out is integer, we may need to cast to float here
    if out.dtype != out_spec["dtype"]:
        out = out.astype(out_spec["dtype"], copy=False)

    # do the actual solving of local matrices in torch
    try:
        torch.linalg.solve(A.larray, b.larray, left=True, out=out.larray)
    except torch._C._LinAlgError as e:
        error_msg = str(e)
    else:
        error_msg = ""

    # A failure on one process must not leave the others running: collect the messages
    # from all processes and raise everywhere if any of them failed.
    global_error_msg = [msg for msg in ht.comm.allgather(error_msg) if msg != ""]
    if global_error_msg:
        raise RuntimeError(*global_error_msg)

    return out
[docs]
def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
    """
    Solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and
    `b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
    The implementation builds on the corresponding solver in PyTorch and implements a memory-distributed, MPI-parallel block-wise version thereof.

    Parameters
    ----------
    A : DNDarray
        An upper triangular invertible square (n x n) matrix or a batch thereof,  i.e. a ``DNDarray`` of shape `(..., n, n)`.
    b : DNDarray
        a (possibly batched) n x k matrix, i.e. a DNDarray of shape (..., n, k), where the batch-dimensions denoted by ... need to coincide with those of A.
        (Batched) Vectors have to be provided as ... x n x 1 matrices. b must not be split along its last dimension; if A or b is split along a batch
        dimension, both must be split along the same dimension.

    Returns
    -------
    DNDarray
        The solution `x`, of the same shape as `b`.

    Raises
    ------
    TypeError
        If `A` or `b` is not a ``DNDarray``.
    ValueError
        If `A` is not a (batched) square matrix, if the dimensions or batch shapes of `A` and `b` do not match,
        if `b` is split along its last dimension, or if a batch split of `A` and `b` differs.
    RuntimeError
        If both `A` and `b` are split in non-batch dimensions and their process-local sizes along the split axis differ.

    Note
    ---------
    Since such a check might be computationally expensive, we do not check whether A is indeed upper triangular.
    If you require such a check, please open an issue on our GitHub page and request this feature.
    """
    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
        raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
    if not A.ndim >= 2:
        raise ValueError("A needs to be a (batched) matrix.")
    if not b.ndim == A.ndim:
        raise ValueError("b needs to have the same number of (batch) dimensions as A.")
    if not A.shape[-2] == A.shape[-1]:
        raise ValueError("A needs to be a (batched) square matrix.")

    batch_dim = A.ndim - 2
    batch_shape = A.shape[:batch_dim]

    if not A.shape[:batch_dim] == b.shape[:batch_dim]:
        raise ValueError("Batch dimensions of A and b must be of the same shape.")
    if b.split == batch_dim + 1:
        raise ValueError("split=1 is not allowed for the right hand side.")
    if not b.shape[batch_dim] == A.shape[-1]:
        raise ValueError("Dimension mismatch of A and b.")

    if (
        A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
    ):  # batch split
        if A.split != b.split:
            raise ValueError(
                "If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
            )
    else:
        if (
            A.split is not None and b.split is not None
        ):  # both la dimensions split --> b.split = batch_dim
            # TODO remove?
            if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
                raise RuntimeError(
                    "The process-local arrays of A and b have different sizes along the splitted axis. This is most likely due to one of the DNDarrays being in unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
                )

    comm = A.comm
    dev = A.device
    tdev = dev.torch_device

    nprocs = comm.Get_size()

    if A.split is None:  # A not split
        if b.split is None:
            x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

            return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
        else:  # A not split, b.split == -2
            # cumulative row offsets of the process-local blocks of b
            b_lshapes_cum = torch.hstack(
                [
                    torch.zeros(1, dtype=torch.int64, device=tdev),
                    torch.cumsum(b.lshape_map[:, -2], 0),
                ]
            )

            btilde_loc = b.larray.clone()
            # columns of A that belong to the rows of b held by this process
            A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

            x = factories.zeros_like(b, device=dev, comm=comm)

            # block-wise backward substitution, starting with the last process
            for i in range(nprocs - 1, 0, -1):
                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
                displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
                count[i:] = 0  # nothing to send, as there are only zero rows
                displ[i:] = 0

                res_send = torch.empty(0)
                # The receive buffer must have the dtype of the data being scattered,
                # as in the split-A branch below; without `dtype` it would always be
                # torch's default float type.
                res_recv = torch.zeros(
                    (*batch_shape, count[comm.rank], b.shape[-1]),
                    device=tdev,
                    dtype=b.dtype.torch_type(),
                )

                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
                        btilde_loc,
                        upper=True,
                    )
                    res_send = A_loc @ x.larray

                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

                # processes that still have to solve update their right-hand side
                if comm.rank < i:
                    btilde_loc -= res_recv

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
                )

            return x

    if A.split < batch_dim:  # batch split
        # every process solves its own batch entries, no communication needed
        x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
        x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

        return x

    if A.split >= batch_dim:  # both splits in la dims
        # cumulative offsets of the process-local blocks of A along its split axis
        A_lshapes_cum = torch.hstack(
            [
                torch.zeros(1, dtype=torch.int64, device=tdev),
                torch.cumsum(A.lshape_map[:, A.split], 0),
            ]
        )

        if b.split is None:
            btilde_loc = b.larray[
                ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
            ].clone()
        else:  # b is split at la dim 0
            btilde_loc = b.larray.clone()

        x = factories.zeros_like(
            b, device=dev, comm=comm, split=batch_dim
        )  # split at la dim 0 in case b is not split

        if A.split == batch_dim + 1:
            for i in range(nprocs - 1, 0, -1):
                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
                displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
                count[i:] = 0  # nothing to send, as there are only zero rows
                displ[i:] = 0

                res_send = torch.empty(0)
                res_recv = torch.zeros(
                    (*batch_shape, count[comm.rank], b.shape[-1]),
                    device=tdev,
                    dtype=b.dtype.torch_type(),
                )

                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
                        btilde_loc,
                        upper=True,
                    )
                    res_send = A.larray @ x.larray

                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

                if comm.rank < i:
                    btilde_loc -= res_recv

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
                )

        else:  # split dim is la dim 0
            for i in range(nprocs - 1, 0, -1):
                idims = tuple(x.lshape_map[i])
                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
                        btilde_loc,
                        upper=True,
                    )
                    x_from_i = x.larray
                else:
                    # receive buffer with the local shape of process i's part of x
                    x_from_i = torch.zeros(
                        idims,
                        dtype=b.dtype.torch_type(),
                        device=tdev,
                    )

                comm.Bcast(x_from_i, root=i)

                if comm.rank < i:
                    btilde_loc -= (
                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
                    )

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
                )

        return x