"""
Basic linear algebra operations on distributed ``DNDarray``
"""
import itertools
import numpy as np
import torch
import warnings
from typing import List, Callable, Union, Optional, Tuple
from torch._C import Value
from ..communication import MPI
from .. import arithmetics
from .. import complex_math
from .. import constants
from .. import exponential
from ..dndarray import DNDarray
from .. import factories
from .. import manipulations
from .. import rounding
from .. import sanitation
from .. import statistics
from .. import stride_tricks
from .. import types
from ..random import randn
from .qr import qr
from .solver import solve_triangular
__all__ = [
"condest",
"cross",
"det",
"dot",
"inv",
"matmul",
"matrix_norm",
"matrix_exp",
"expm",
"norm",
"outer",
"projection",
"trace",
"transpose",
"tril",
"triu",
"vdot",
"vecdot",
"vector_norm",
]
def _estimate_largest_singularvalue(A: DNDarray, algorithm: str = "fro") -> DNDarray:
"""
Computes an upper estimate for the largest singular value of the input 2D DNDarray.
Parameters
----------
A : DNDarray
The matrix, i.e., a 2D DNDarray, for which the largest singular value should be estimated.
algorithm : str
The algorithm to use for the estimation. Currently, only "fro" (default) is implemented.
If "fro" is chosen, the Frobenius norm of the matrix is used as an upper estimate.
"""
if not isinstance(algorithm, str):
raise TypeError(
f"Parameter 'algorithm' needs to be a string, but is {algorithm} with data type {type(algorithm)}."
)
if algorithm == "fro":
return matrix_norm(A, ord="fro").squeeze()
else:
raise NotImplementedError("So far only algorithm='fro' implemented.")
[docs]
def condest(
A: DNDarray, p: Union[int, str] = None, algorithm: str = "randomized", params: list = None
) -> DNDarray:
"""
Computes a (possibly randomized) upper estimate of the l2-condition number of the input 2D DNDarray.
Parameters
----------
A : DNDarray
The matrix, i.e., a 2D DNDarray, for which the condition number shall be estimated.
p : int or str (optional)
The norm to use for the condition number computation. If None, the l2-norm (default, p=2) is used.
So far, only p=2 is implemented.
algorithm : str
The algorithm to use for the estimation. Currently, only "randomized" (default) is implemented.
params : dict (optional)
A list of parameters required for the chosen algorithm; if not provided, default values for the respective algorithm are chosen.
If `algorithm="randomized"` the number of random samples to use can be specified under the key "nsamples"; default is 10.
Notes
-----
The "randomized" algorithm follows the approach described in [1]; note that in the paper actually the condition number w.r.t. the Frobenius norm is estimated.
However, this yields an upper bound for the condition number w.r.t. the l2-norm as well.
References
----------
[1] T. Gudmundsson, C. S. Kenney, and A. J. Laub. Small-Sample Statistical Estimates for Matrix Norms. SIAM Journal on Matrix Analysis and Applications 1995 16:3, 776-792.
"""
if p is None:
p = 2
if p != 2:
raise ValueError(
f"Only the case p=2 (condition number w.r.t. the euclidean norm) is implemented so far, but input was p={p} (type: {type(p)})."
)
if not isinstance(algorithm, str):
raise TypeError(
f"Parameter 'algorithm' needs to be a string, but is {algorithm} with data type {type(algorithm)}."
)
if algorithm == "randomized":
if params is None:
nsamples = 10 # set default value
else:
if not isinstance(params, dict) or "nsamples" not in params:
raise TypeError(
"If not None, 'params' needs to be a dictionary containing the number of samples under the key 'nsamples'."
)
if not isinstance(params["nsamples"], int) or params["nsamples"] <= 0:
raise ValueError(
f"The number of samples needs to be a positive integer, but is {params['nsamples']} with data type {type(params['nsamples'])}."
)
nsamples = params["nsamples"]
m = A.shape[0]
n = A.shape[1]
if n > m:
# the algorithm only works for m >= n, but fortunately, the condition number (w.r.t. l2-norm) is invariant under transposition
return condest(A.T, p=p, algorithm=algorithm, params=params)
_, R = qr(A, mode="r") # only R factor is computed in QR
# random samples from unit sphere
# regarding the split: if A.split == 1, then n is probably large and we should split along an axis of size n; otherwise, both n and nsamples should be small
Q, R_not_used = qr(
randn(
n,
nsamples,
dtype=A.dtype,
split=0 if A.split == 1 else None,
device=A.device,
comm=A.comm,
)
)
del R_not_used
est = (
matrix_norm(R @ Q)
* A.dtype((m / nsamples) ** 0.5, comm=A.comm)
* matrix_norm(solve_triangular(R, Q))
)
return est.squeeze()
else:
raise NotImplementedError(
"So far only algorithm='randomized' is implemented. Please open an issue on GitHub if you would like to suggest implementing another algorithm."
)
[docs]
def cross(
a: DNDarray, b: DNDarray, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int = -1
) -> DNDarray:
"""
Returns the cross product. 2D vectors will we converted to 3D.
Parameters
----------
a : DNDarray
First input array.
b : DNDarray
Second input array. Must have the same shape as 'a'.
axisa: int
Axis of `a` that defines the vector(s). By default, the last axis.
axisb: int
Axis of `b` that defines the vector(s). By default, the last axis.
axisc: int
Axis of the output containing the cross product vector(s). By default, the last axis.
axis : int
Axis that defines the vectors for which to compute the cross product. Overrides `axisa`, `axisb` and `axisc`. Default: -1
Raises
------
ValueError
If the two input arrays don't match in shape, split, device, or comm. If the vectors are along the split axis.
TypeError
If 'axis' is not an integer.
Examples
--------
>>> a = ht.eye(3)
>>> b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
>>> cross = ht.cross(a, b)
DNDarray([[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.]], dtype=ht.float32, device=cpu:0, split=None)
"""
sanitation.sanitize_in(a)
sanitation.sanitize_in(b)
if a.device != b.device:
raise ValueError(f"'a' and 'b' must have the same device type, {a.device} != {b.device}")
if a.comm != b.comm: # pragma: no cover
raise ValueError(f"'a' and 'b' must have the same comm, {a.comm} != {b.comm}")
a_2d, b_2d = False, False
a_shape, b_shape = list(a.shape), list(b.shape)
if axis != -1 or torch.unique(torch.tensor([axisa, axisb, axisc, axis])).numel() == 1:
axis = stride_tricks.sanitize_axis(a.shape, axis)
axisa, axisb, axisc = (axis,) * 3
else:
axisa = stride_tricks.sanitize_axis(a.shape, axisa)
axisb = stride_tricks.sanitize_axis(b.shape, axisb)
axisc = stride_tricks.sanitize_axis(a.shape, axisc)
if a.split == axisa or b.split == axisb:
raise ValueError(
"The computation of the cross product with vectors along the split axis is not supported."
)
# all dimensions except axisa, axisb must be broadcastable
del a_shape[axisa], b_shape[axisb]
output_shape = stride_tricks.broadcast_shape(a_shape, b_shape)
# 2d -> 3d vector
if a.shape[axisa] == 2:
a_2d = True
shape = tuple(1 if i == axisa else j for i, j in enumerate(a.shape))
a = manipulations.concatenate(
[a, factories.zeros(shape, dtype=a.dtype, device=a.device, comm=a.comm)], axis=axisa
)
if b.shape[axisb] == 2:
b_2d = True
shape = tuple(1 if i == axisb else j for i, j in enumerate(b.shape))
b = manipulations.concatenate(
[b, factories.zeros(shape, dtype=b.dtype, device=b.device)], axis=axisb
)
if axisc != axisa:
a = manipulations.moveaxis(a, axisa, axisc)
if axisc != axisb:
b = manipulations.moveaxis(b, axisb, axisc)
axis = axisc
# by now split axes must be aligned
if a.split != b.split:
raise ValueError(f"'a' and 'b' must have the same split, {a.split} != {b.split}")
if not (a.is_balanced and b.is_balanced):
# TODO: replace with sanitize_redistribute after #888 is merged
b = manipulations.redistribute(b, b.lshape_map, a.lshape_map)
promoted = torch.promote_types(a.larray.dtype, b.larray.dtype)
ret = torch.cross(a.larray.type(promoted), b.larray.type(promoted), dim=axis)
# if both vector axes have dimension 2, return the z-component of the cross product
if a_2d and b_2d:
z_slice = [slice(None, None, None)] * ret.ndim
z_slice[axisc] = -1
ret = ret[tuple(z_slice)]
else:
output_shape = output_shape[:axis] + (3,) + output_shape[axis:]
ret = DNDarray(ret, output_shape, types.heat_type_of(ret), a.split, a.device, a.comm, True)
return ret
[docs]
def det(a: DNDarray) -> DNDarray:
"""
Returns the determinant of a square matrix.
Parameters
----------
a : DNDarray
A square matrix or a stack of matrices. Shape = (...,M,M)
Raises
------
RuntimeError
If the dtype of 'a' is not floating-point.
RuntimeError
If `a.ndim < 2` or if the length of the last two dimensions is not the same.
Examples
--------
>>> a = ht.array([[-2, -1, 2], [2, 1, 4], [-3, 3, -1]])
>>> ht.linalg.det(a)
DNDarray(54., dtype=ht.float64, device=cpu:0, split=None)
"""
sanitation.sanitize_in(a) # pragma: no cover
if a.ndim < 2:
raise RuntimeError("DNDarray must be at least two-dimensional.")
m, n = a.shape[-2:]
if m != n:
raise RuntimeError("Last two dimensions of the DNDarray must be square.")
if types.heat_type_is_exact(a.dtype):
raise RuntimeError("dtype of DNDarray must be floating-point.")
# no split in the square matrices
if not a.is_distributed() or a.split < a.ndim - 2:
data = torch.linalg.det(a.larray)
sp = a.split if a.is_distributed() else None
return DNDarray(
data,
a.shape[:-2],
types.heat_type_of(data),
split=sp,
device=a.device,
comm=a.comm,
balanced=a.balanced,
)
acopy = a.copy()
acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3)
adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device, comm=a.comm)
for k in range(adet.shape[0]):
m = 0
for i in range(n):
# partial pivoting
if np.isclose(acopy[k, i, i].item(), 0):
abord = True
for j in range(i + 1, n):
if not np.isclose(acopy[k, j, i].item(), 0):
if a.split == a.ndim - 2: # split=0 on square matrix
acopy[k, i, :], acopy[k, j, :] = acopy[k, j, :], acopy[k, i, :].copy()
else: # split=1
acopy.larray[k, i, :], acopy.larray[k, j, :] = (
acopy.larray[k, j, :],
acopy.larray[k, i, :].clone(),
)
abord = False
m += 1
break
if abord:
adet[k] = 0
break
adet[k] *= acopy[k, i, i]
z = acopy[k, i + 1 :, i, None].larray / acopy[k, i, i].item()
acopy[k, i + 1 :, :].larray -= z * acopy[k, i, :].larray
if m % 2 != 0:
adet[k] = -adet[k]
adet = manipulations.reshape(adet, a.shape[:-2])
return adet
[docs]
def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
"""
Returns the dot product of two ``DNDarrays``.
Specifically,
1. If both a and b are 1-D arrays, it is inner product of vectors.
2. If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or ``a@b`` is preferred.
3. If either a or b is 0-D (scalar), it is equivalent to multiply and using ``multiply(a, b)`` or ``a*b`` is preferred.
Parameters
----------
a : DNDarray
First input DNDarray
b : DNDarray
Second input DNDarray
out : DNDarray, optional
Output buffer.
See Also
--------
vecdot
Supports (vector) dot along an axis.
"""
if isinstance(a, (float, int)) or isinstance(b, (float, int)) or a.ndim == 0 or b.ndim == 0:
# 3. If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
if out is not None:
out = a * b
return out
return a * b
elif a.ndim == 1 and b.ndim == 1:
# 1. If both a and b are 1-D arrays, it is inner product of vectors.
if a.split is None and b.split is None:
sl = slice(None)
asl = bsl = sl
# st = 0
else: # at least one of them is split
# todo: scale this by the starting index of the vector and do a lloc getitem
st, _, sl = a.comm.chunk(a.shape, a.split if a.split is not None else b.split)
asl = sl if a.split is None else slice(sl[0].start - st, sl[0].stop - st)
bsl = sl if b.split is None else slice(sl[0].start - st, sl[0].stop - st)
ret = torch.dot(a.lloc[asl], b.lloc[bsl])
if a.is_distributed() or b.is_distributed():
a.comm.Allreduce(MPI.IN_PLACE, ret, MPI.SUM)
if out is not None:
out = DNDarray(ret, (), types.heat_type_of(ret), None, a.device, a.comm, True)
return out
return DNDarray(ret, (), types.heat_type_of(ret), None, a.device, a.comm, True)
elif a.ndim <= 2 and b.ndim <= 2:
# 2. If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
ret = matmul(a, b)
if out is not None:
out.larray = ret.larray
out._DNDarray__dtype = ret.dtype
out._DNDarray__split = ret.split
out._DNDarray__device = ret.device
out._DNDarray__comm = ret.comm
return out
return ret
else:
raise NotImplementedError("ht.dot not implemented for N-D dot M-D arrays")
[docs]
def inv(a: DNDarray) -> DNDarray:
"""
Computes the multiplicative inverse of a square matrix.
Parameters
----------
a : DNDarray
Square matrix of floating-point data type or a stack of square matrices. Shape = (...,M,M)
Raises
------
RuntimeError
If the inverse does not exist.
RuntimeError
If the dtype is not floating-point
RuntimeError
If a is not at least two-dimensional or if the lengths of the last two dimensions are not the same.
Examples
--------
>>> a = ht.array([[1.0, 2], [2, 3]])
>>> ht.linalg.inv(a)
DNDarray([[-3., 2.],
[ 2., -1.]], dtype=ht.float32, device=cpu:0, split=None)
"""
sanitation.sanitize_in(a) # pragma: no cover
if a.ndim < 2:
raise RuntimeError("DNDarray must be at least two-dimensional.")
m, n = a.shape[-2:]
if m != n:
raise RuntimeError("Last two dimensions of the DNDarray must be square.")
if types.heat_type_is_exact(a.dtype):
raise RuntimeError("dtype of DNDarray must be floating-point.")
# no split in the square matrices
if not a.is_distributed() or a.split < a.ndim - 2:
try:
data = torch.inverse(a.larray)
except RuntimeError as e:
raise RuntimeError(e)
# torch.linalg.inv does not raise RuntimeError on MPS when inversion fails
if data.is_mps and torch.any(data.isnan()):
raise RuntimeError("linalg.inv: inversion could not be performed")
return DNDarray(
data,
a.shape,
types.heat_type_of(data),
split=a.split,
device=a.device,
comm=a.comm,
balanced=a.balanced,
)
acopy = a.copy()
acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3)
ainv = factories.zeros_like(acopy)
for i in range(m):
ainv[:, i, i] = 1
_, displs = acopy.counts_displs()
for k in range(ainv.shape[0]):
rank = 0
for i in range(n):
# partial pivoting
if np.isclose(acopy[k, i, i].item(), 0):
abord = True
for j in range(i + 1, n):
if not np.isclose(acopy[k, j, i].item(), 0):
if a.split == a.ndim - 2: # split=0 on square matrix
ainv[k, i, :], ainv[k, j, :] = ainv[k, j, :], ainv[k, i, :].copy()
acopy[k, i, :], acopy[k, j, :] = acopy[k, j, :], acopy[k, i, :].copy()
else: # split=1
acopy.larray[k, i, :], acopy.larray[k, j, :] = (
acopy.larray[k, j, :],
acopy.larray[k, i, :].clone(),
)
ainv.larray[k, i, :], ainv.larray[k, j, :] = (
ainv.larray[k, j, :],
ainv.larray[k, i, :].clone(),
)
abord = False
break
if abord:
raise RuntimeError("Inverse does not exist")
scale = acopy[k, i, i].item()
# Circumvent an issue with DNDarray setter and getter that caused precision errors
if a.split == a.ndim - 2:
if rank < acopy.comm.size - 1 and i >= displs[rank + 1]:
rank += 1
if acopy.comm.rank == rank:
ainv.larray[k, i - displs[rank], :] /= scale
acopy.larray[k, i - displs[rank], :] /= scale
else:
ainv[k, i, :].larray /= scale
acopy[k, i, :].larray /= scale
factor = acopy[k, i + 1 :, i, None].larray
ainv[k, i + 1 :, :].larray -= factor * ainv[k, i, :].larray
acopy[k, i + 1 :, :].larray -= factor * acopy[k, i, :].larray
# backwards
for i in range(n - 1, 0, -1):
factor = acopy[k, :i, i, None].larray
ainv[k, :i, :].larray -= factor * ainv[k, i, :].larray
acopy[k, :i, :].larray -= factor * acopy[k, i, :].larray
ainv = manipulations.reshape(ainv, a.shape, new_split=a.split)
return ainv
[docs]
def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
"""
Matrix multiplication of two ``DNDarrays``: ``a@b=c`` or ``A@B=c``.
Returns a tensor with the result of ``a@b``. The split dimension of the returned array is
typically the split dimension of a. If both are ``None`` and if ``allow_resplit=False`` then ``c.split`` is also ``None``.
Batched inputs (with batch dimensions being leading dimensions) are allowed; see also the Notes below.
Parameters
----------
a : DNDarray
matrix :math:`L \\times P` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times L \\times P`
b : DNDarray
matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times P \\times Q`
allow_resplit : bool, optional
Whether to distribute ``a`` in the case that both ``a.split is None`` and ``b.split is None``.
Default is ``False``. If ``True``, if both are not split then ``a`` will be distributed in-place along axis 0.
Notes
-----
- For batched inputs, batch dimensions must coincide and if one matrix is split along a batch axis the other must be split along the same axis.
- If ``a`` or ``b`` is a vector the result will also be a vector.
- We recommend to avoid the particular split combinations ``1``-``0``, ``None``-``0``, and ``1``-``None`` (for ``a.split``-``b.split``) due to their comparably high memory consumption, if possible. Applying ``DNDarray.resplit_`` or ``heat.resplit`` on one of the two factors before calling ``matmul`` in these situations might improve performance of your code / might avoid memory bottlenecks.
References
----------
[1] R. Gu, et al., "Improving Execution Concurrency of Large-scale Matrix Multiplication on
Distributed Data-parallel Platforms," IEEE Transactions on Parallel and Distributed Systems,
vol 28, no. 9. 2017. \n
[2] S. Ryu and D. Kim, "Parallel Huge Matrix Multiplication on a Cluster with GPGPU
Accelerators," 2018 IEEE International Parallel and Distributed Processing Symposium
Workshops (IPDPSW), Vancouver, BC, 2018, pp. 877-882.
Examples
--------
>>> a = ht.ones((n, m), split=1)
>>> a[0] = ht.arange(1, m + 1)
>>> a[:, -1] = ht.arange(1, n + 1).larray
[0/1] tensor([[1., 2.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])
[1/1] tensor([[3., 1.],
[1., 2.],
[1., 3.],
[1., 4.],
[1., 5.]])
>>> b = ht.ones((j, k), split=0)
>>> b[0] = ht.arange(1, k + 1)
>>> b[:, 0] = ht.arange(1, j + 1).larray
[0/1] tensor([[1., 2., 3., 4., 5., 6., 7.],
[2., 1., 1., 1., 1., 1., 1.]])
[1/1] tensor([[3., 1., 1., 1., 1., 1., 1.],
[4., 1., 1., 1., 1., 1., 1.]])
>>> linalg.matmul(a, b).larray
[0/1] tensor([[18., 8., 9., 10.],
[14., 6., 7., 8.],
[18., 7., 8., 9.],
[22., 8., 9., 10.],
[26., 9., 10., 11.]])
[1/1] tensor([[11., 12., 13.],
[ 9., 10., 11.],
[10., 11., 12.],
[11., 12., 13.],
[12., 13., 14.]])
"""
sanitation.sanitize_in(a)
sanitation.sanitize_in(b)
batch_dim = max(a.ndim, b.ndim) - 2 # -1 for vector vector multiplication
batched = batch_dim > 0
if batched and a.gshape[:batch_dim] != b.gshape[:batch_dim]:
raise ValueError("Batch dimensions must have the same shape!")
batch_shape = a.gshape[:batch_dim]
# if they are vectors they need to be expanded to be the proper dimensions
vector_flag_a = vector_flag_b = False
# if a.ndim >= 2 or b.ndim >= 2: # other case gets early out
if a.ndim == b.ndim - 1:
vector_flag_a = True
elif b.ndim == a.ndim - 1:
vector_flag_b = True
vector_flag = vector_flag_a or vector_flag_b # run squeeze at the end
if not vector_flag and a.ndim != b.ndim:
raise ValueError("Number of batch dimensions must be the same!")
if batch_dim >= 0: # not vector vector mult
na = a.gshape[-1]
mb = b.gshape[-2] if not vector_flag_b else b.gshape[-1]
if na != mb:
raise ValueError(
f"The last dimension of a ({a.gshape[-1]}) is not the same size as the second-to-last dimension of b. ({b.gshape[-2]})"
)
if batched:
# check for valid batched split of a and b
# if one is split along a batch axis, both matrices must be split along that axis
if (
a.split is not None
and a.split < batch_dim
or b.split is not None
and b.split < batch_dim
) and a.split != b.split: # not the same batch axis for split
raise NotImplementedError(
"Both input matrices have to be split along the same batch axis!"
)
if vector_flag: # batched matrix vector multiplication not supported
raise NotImplementedError(
"Batched matrix-vector multiplication is not supported, try using expand_dims to make it a batched matrix-matrix multiplication."
)
comm = a.comm
ndim = max(a.ndim, b.ndim)
dev = a.device
tdev = dev.torch_device
# determine if a larger type is needed for c
c_type = types.promote_types(a.dtype, b.dtype)
gpu_int_flag = False
if str(dev)[:3] == "gpu":
og_type = c_type
if c_type in [types.uint8, types.int8, types.int16, types.int32]:
c_type = types.float32
gpu_int_flag = True
elif c_type == types.int64:
c_type = types.float64
gpu_int_flag = True
if a.dtype != c_type:
a = c_type(a, device=dev)
if b.dtype != c_type:
b = c_type(b, device=dev)
c = None
# single-process setup, torch matmul
if a.comm.size == 1:
c = factories.array(torch.matmul(a.larray, b.larray), dtype=c_type, device=dev)
# early out for vector vector multiplication
# is this even covered in the tests?
# seems to be used in test_qr
elif a.ndim == 1 and b.ndim == 1:
# make both split 0, do a local mm then a sum
a.resplit_(0)
b.resplit_(0)
res = a.larray @ b.larray
a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
c = factories.array(res, split=None, device=dev, comm=comm)
elif a.split is None and b.split is None: # None-None
if allow_resplit and not vector_flag: # resplit a to 0
a.resplit_(ndim - 2)
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray
c = factories.zeros(
(*batch_shape, a.gshape[-2], b.gshape[-1]), dtype=c_type, device=dev, comm=comm
)
c.larray[..., slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
else: # torch matmul
c = factories.array(
torch.matmul(a.larray, b.larray),
dtype=c_type,
device=dev,
comm=comm,
)
elif a.split is not None and a.split < batch_dim: # split in batch dimension
c = factories.array(
torch.matmul(a.larray, b.larray),
is_split=a.split,
dtype=c_type,
device=dev,
comm=comm,
)
if c is not None: # early out
if gpu_int_flag:
c = og_type(c, device=dev)
return c
# vector expansions
if vector_flag_a:
a = manipulations.expand_dims(a, axis=batch_dim)
if vector_flag_b:
b = manipulations.expand_dims(b, axis=batch_dim + 1)
c_shape = (*batch_shape, a.gshape[-2], b.gshape[-1])
# one split None => other one is la dimension
if a.split is None or b.split is None:
split = None
is_split = False
if (a.split == ndim - 2 and b.split is None) or (
a.split is None and b.split == ndim - 1
): # 0-None, None-1
split = a.split if a.split is not None else b.split
is_split = True
c = a.larray @ b.larray
elif a.split == ndim - 1 and b.split is None: # 1-None
split = a.split
c = torch.zeros(c_shape, dtype=c_type.torch_type(), device=tdev)
a_idx = comm.chunk(a.shape, a.split)[2]
c += (
a.larray
@ b.larray[
..., a_idx[ndim - 1].start : a_idx[ndim - 1].start + a.lshape[ndim - 1], :
]
)
comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
elif a.split is None and b.split == ndim - 2: # None-0
split = b.split
c = torch.zeros(c_shape, dtype=c_type.torch_type(), device=tdev)
b_idx = b.comm.chunk(b.shape, b.split)[2]
c += (
a.larray[..., b_idx[ndim - 2].start : b_idx[ndim - 2].start + b.lshape[ndim - 2]]
@ b.larray
)
b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
# early out
if vector_flag: # squeeze only in the la dimensions
# it could be sensible to resplit/rebalance in case a single node gets the whole vector
if split is not None and split > batch_dim: # split in dimension that gets squeezed
split = batch_dim
if c.numel() == 0: # empty tensor cannot be squeezed
c = torch.zeros((*batch_shape, 0), dtype=c_type.torch_type(), device=tdev)
else:
c = c.squeeze(batch_dim)
if c.ndim >= batch_dim + 2:
c = c.squeeze(batch_dim + 1)
c = factories.array(
c,
split=split if not is_split else None,
is_split=split if is_split else None,
dtype=c_type,
device=dev,
comm=comm,
)
if gpu_int_flag:
c = og_type(c, device=dev)
return c
else:
# block sizes dont need to be the same. they just need the same inner dimension (kB)
kB = 0 # redundant?
rem_a, rem_b = 0, 0
if a.split == ndim - 1 and b.split == ndim - 2: # split 10
# if the split direction is the last dim in a and the first dim in b
# the max inner dim (kB) is the min value from the result of the integer division
# of the last dim of a/world size and the first dim of b/world size
kB = min(
[a.gshape[-1] // comm.size, b.gshape[-2] // comm.size]
) # a.gshape[-1] == b.gshape[-2]
elif a.split == ndim - 2 and b.split == ndim - 1: # split 01
kB = a.gshape[-1]
elif a.split == ndim - 1: # split 11
kB = a.gshape[-1] // comm.size
elif b.split == ndim - 2: # split 00
kB = b.gshape[-2] // comm.size
kB = min(
kB, a.gshape[-1]
) # shouldnt this always be kB and be the same as for split 11?
if (kB == 1 and a.lshape[-1] != 1) or a.lshape[
-1
] % kB != 0: # does kb == 1 imply a.lshape[-1] > 1?
rem_a = 1
if (kB == 1 and b.lshape[-2] != 1) or b.lshape[-2] % kB != 0:
rem_b = 1
# get the lshape map to determine what needs to be sent where as well as M and N
# lshape map dims -> {node, a=0 | b=1, lshape}
lshape_map = torch.zeros((comm.size, 2, ndim), dtype=int, device=tdev)
lshape_map[comm.rank, 0, :] = torch.tensor(a.lshape, device=tdev)
lshape_map[comm.rank, 1, :] = torch.tensor(b.lshape, device=tdev)
comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)
# find mB (first blocking dim for a) and nB (2nd blocking dim for b)
mB = lshape_map[:, 0, -2].min().item() # smallest number of local rows of a on a node
nB = lshape_map[:, 1, -1].min().item() # smallest number of local columns of b on a node
# check for remaining dims in the outside dimensions
rem_a_out, rem_b_out = 0, 0
if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1):
rem_a_out = 1
if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1):
rem_b_out = 1
# get the flags from all processes
# rem_map dims guide -> {process number, a/b (0/1), dim0/dim1 (0/1), True/False (1/0)
# if there is a remainder in this dimension
rem_map = torch.zeros((comm.size, 2, 2))
rem_map[comm.rank, 0, :] = torch.tensor((rem_a_out, rem_a), device=tdev)
rem_map[comm.rank, 1, :] = torch.tensor((rem_b, rem_b_out), device=tdev)
rem_map_comm = comm.Iallreduce(MPI.IN_PLACE, rem_map, MPI.SUM)
# index_map dims guide -> {process number, a=0/b=1, relevant 1st index, 2nd index}
index_map = torch.zeros((comm.size, 2, 2, 2), dtype=int, device=tdev)
a_idx = comm.chunk(a.shape, a.split)[2]
index_map[comm.rank, 0, 0] = torch.tensor((a_idx[-2].start, a_idx[-2].stop), device=tdev)
index_map[comm.rank, 0, 1] = torch.tensor((a_idx[-1].start, a_idx[-1].stop), device=tdev)
b_idx = comm.chunk(b.shape, b.split)[2]
index_map[comm.rank, 1, 0] = torch.tensor((b_idx[-2].start, b_idx[-2].stop), device=tdev)
index_map[comm.rank, 1, 1] = torch.tensor((b_idx[-1].start, b_idx[-1].stop), device=tdev)
index_map_comm = comm.Iallreduce(MPI.IN_PLACE, index_map, MPI.SUM)
# output: c = a @ b
# for the communication scheme, the output array needs to be created
c_shape = (*batch_shape, a.gshape[-2], b.gshape[-1])
c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=dev, comm=comm)
# get the index map for c
c_index_map = factories.zeros((c.comm.size, 2, 2), device=dev, comm=comm)
c_idx = comm.chunk(c.shape, c.split)[2]
c_index_map[comm.rank, 0, :] = (c_idx[-2].start, c_idx[-2].stop)
c_index_map[comm.rank, 1, :] = (c_idx[-1].start, c_idx[-1].stop)
c_index_map_comm = comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM)
if a.split == ndim - 2:
a_block_map = torch.zeros(
(comm.size, a.shape[-2] // mB // comm.size, a.shape[-1] // kB, 2),
dtype=torch.int,
device=tdev,
)
elif a.split == ndim - 1: # else should be equivalent at this point
a_block_map = torch.zeros(
(comm.size, a.shape[-2] // mB, a.shape[-1] // kB // comm.size, 2),
dtype=torch.int,
device=tdev,
)
# units-> [process, dim0 block number, dim1 block number, start coord] **indices are local
# below is to handle the edge case where there is only one element in one dimension of a
a_d0_1s_flag, a_d1_1s_flag = False, False
if any(lshape_map[:, 0, :][:, -2] == 1):
a_d0_1s_flag = True
if any(lshape_map[:, 0, :][:, -1] == 1):
a_d1_1s_flag = True
index_map_comm.Wait()
for pr in range(comm.size):
start0 = index_map[pr, 0, 0, 0].item()
stop0 = index_map[pr, 0, 0, 1].item()
start1 = index_map[pr, 0, 1, 0].item()
stop1 = index_map[pr, 0, 1, 1].item()
# maybe we could use torch.arange instead of this nested loop
for dim0 in range(
(stop0 - start0) // mB // comm.size if a_d0_1s_flag else (stop0 - start0) // mB
):
# loop over the number of blocks in the 0th dimension
for dim1 in range(
(stop1 - start1) // kB // comm.size if a_d1_1s_flag else (stop1 - start1) // kB
):
# loop over the number of blocks in the 1st dimension
a_block_map[pr, dim0, dim1] = torch.tensor(
(dim0 * mB, dim1 * kB), dtype=torch.int, device=tdev
)
rem_map_comm.Wait()
if b.split == ndim - 2:
# the blocks are shifted in the 2nd dimension of A for as many remainders
# there are between the blocks in the first dim of B
cnt = 0
for r in rem_map[:, 1, 0]:
if r.item():
cnt += 1
# why increment by exactly 1? what can we assume about the lshapes on different nodes?
# can the sizes in the split dimension differ by more than 1?
a_block_map[:, :, cnt:, 1] += 1
b_block_map = torch.zeros(
(comm.size, b.shape[-2] // kB // comm.size, b.shape[-1] // nB, 2),
dtype=torch.int,
device=tdev,
)
else: # b split 1
b_block_map = torch.zeros(
(comm.size, b.shape[-2] // kB, b.shape[-1] // nB // comm.size, 2),
dtype=torch.int,
device=tdev,
)
# units-> [process, dim0 block number, dim1 block number, start coord] **indices are local
# below is to handle the edge case where there is only one element in one dimension of b
b_d0_1s_flag, b_d1_1s_flag = False, False
if any(lshape_map[:, 1, :][:, -2] == 1):
b_d0_1s_flag = True
if any(lshape_map[:, 1, :][:, -1] == 1):
b_d1_1s_flag = True
for pr in range(b.comm.size):
start0 = index_map[pr, 1, 0, 0].item()
stop0 = index_map[pr, 1, 0, 1].item()
start1 = index_map[pr, 1, 1, 0].item()
stop1 = index_map[pr, 1, 1, 1].item()
# loop over the number of blocks in the 0th dimension
for dim0 in range(
(stop0 - start0) // kB // b.comm.size if b_d0_1s_flag else (stop0 - start0) // kB
):
# loop over the number of blocks in the 1st dimension
for dim1 in range(
(stop1 - start1) // nB // b.comm.size
if b_d1_1s_flag
else (stop1 - start1) // nB
):
b_block_map[pr, dim0, dim1] = torch.tensor(
(dim0 * kB, dim1 * nB), dtype=torch.int, device=tdev
)
if a.split == ndim - 1:
cnt = 0
# this loop will push the blocks in B to adjust for the remainders in A
for r in rem_map[:, 0, 1]:
if r.item():
cnt += 1
b_block_map[:, cnt:, :, 0] += 1
# work loop: loop over all processes (also will incorporate the remainder calculations)
c_index_map_comm.Wait()
# split la dims 00
if a.split == ndim - 2 and b.split == ndim - 2:
# need to send b here and not a
# the rows on 'a' are complete, and the columns of 'b' are split
# locations of the remainders in b
b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False)
a_rem_locs0 = torch.nonzero(rem_map[:, 0, 0] == 1, as_tuple=False)
# remainders for a in the
a_node_rem_s0 = a.larray[..., :mB, kB : (kB + 1) * b_rem_locs0.numel() : kB + 1]
b_rem = torch.empty(
(*batch_shape, b_rem_locs0.numel(), b.lshape[-1]),
dtype=a.dtype.torch_type(),
device=tdev,
)
# this if/elif/else loop is for the handling of
if comm.rank in a_rem_locs0:
# if A is split in dim0 and the rank has a remainder in this direction
r = a.larray[..., -1, :].unsqueeze(-2)
# can we not just set r_loc = -1 instead?
r_loc = index_map[comm.rank, 0, 0, 1] - index_map[comm.rank, 0, 0, 0] - 1
else:
r = None
r_loc = None
req = {}
b_lp_data = {}
for pr in range(comm.size):
# ibcast data on node first
if comm.rank == pr:
b_lp_data[pr] = b.larray.clone()
else:
b_lp_data[pr] = torch.zeros(
(*batch_shape, lshape_map[pr, 1, -2].item(), lshape_map[pr, 1, -1].item()),
dtype=b.dtype.torch_type(),
device=tdev,
)
# sending a to all nodes for b to operate with
req[pr] = comm.Ibcast(b_lp_data[pr], root=pr)
# receive the data from the last loop and do the calculation with that
if pr != 0:
req[pr - 1].Wait()
# after receiving the last loop's bcast
__mm_c_block_setter(
b_proc=pr - 1,
a_proc=comm.rank,
a_data=a.larray,
b_data=b_lp_data[pr - 1],
b_block_map=b_block_map,
a_block_map=a_block_map,
b_split=0,
a_split=0,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)
# check if there is a remainder on b in the previous node
# this loop is intended to get the remainders of b since it is the one being passed
if pr - 1 in b_rem_locs0:
# takes care of the remainders in b as well as dim0 of a
b_rem[..., pr - 1, :] = b_lp_data[pr - 1][..., -1, :]
# this loop is to take care of the remainders in dim0 of a
if a_rem_locs0.nelement() != 0 and r_loc is not None:
st = index_map[pr - 1, 1, 0, 0].item()
sp = index_map[pr - 1, 1, 0, 1].item()
c.larray[..., r_loc.item(), :] += (
r[..., st:sp] @ b_lp_data[pr - 1]
).squeeze(-2)
del b_lp_data[pr - 1]
# need to wait if its the last loop, also need to collect the remainders
if pr == comm.size - 1:
req[pr].Wait()
__mm_c_block_setter(
b_proc=pr,
a_proc=comm.rank,
a_data=a.larray,
b_data=b_lp_data[pr],
b_block_map=b_block_map,
a_block_map=a_block_map,
b_split=0,
a_split=0,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)
# check if there is a remainder on b on the last node (there shouldnt be)
if pr in b_rem_locs0:
# this is to save the data from B required by the remainders from dim1 of A
b_rem[..., pr, :] = b_lp_data[pr][..., -1, :]
# this loop is to take care of the remainders in the 0th dimension of A
if a_rem_locs0.nelement() != 0 and r_loc is not None:
st = index_map[pr, 1, 0, 0].item()
sp = index_map[pr, 1, 0, 1].item() # linear algebra dimension 0/1
# code not reachable?
# if split_01_flag:
if False:
st1 = index_map[pr, 1, 1, 0].item()
sp1 = index_map[pr, 1, 1, 1].item()
c.larray[..., r_loc.item(), st1:sp1] += r[..., st:sp] @ b_lp_data[pr]
else:
c.larray[..., r_loc.item(), :] += (
r[..., st:sp] @ b_lp_data[pr]
).squeeze(-2)
# set the final blocks on the last loop, then adjust for the
# the remainders which were collected in b_rem
if b_rem_locs0.numel():
c.larray[..., : a_node_rem_s0.shape[-2], :] += (
a_node_rem_s0 @ b_rem
) # shouldnt shape[0] always be mB?
del b_lp_data[pr]
# split la dims 01
elif a.split == ndim - 2 and b.split == ndim - 1:
# for this case there are no remainders which need to be taken care of
req = {}
b_lp_data = {}
for pr in range(comm.size):
# ibcast data on node first
if comm.rank == pr:
b_lp_data[pr] = b.larray.clone()
else:
b_lp_data[pr] = torch.empty(
(*batch_shape, lshape_map[pr, 1, -2].item(), lshape_map[pr, 1, -1].item()),
dtype=b.dtype.torch_type(),
device=tdev,
)
# sending a to all nodes for b to operate with
req[pr] = comm.Ibcast(b_lp_data[pr], root=pr)
# receive the data from the last loop and do the calculation with that
if pr != 0:
req[pr - 1].Wait()
# after receiving the last loop's bcast
st0 = index_map[pr - 1, 0, 0, 0].item()
sp0 = index_map[pr - 1, 0, 0, 1].item() + 1
st1 = index_map[pr - 1, 1, 1, 0].item()
sp1 = index_map[pr - 1, 1, 1, 1].item()
c.larray[..., : sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr - 1]
del b_lp_data[pr - 1]
if pr == comm.size - 1:
req[pr].Wait()
st0 = index_map[pr, 0, 0, 0].item()
sp0 = index_map[pr, 0, 0, 1].item() + 1
st1 = index_map[pr, 1, 1, 0].item()
sp1 = index_map[pr, 1, 1, 1].item()
c.larray[..., : sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr]
del b_lp_data[pr]
# split la dims 11
elif a.split == ndim - 1 and b.split == ndim - 1:
# for this case, a is sent to b
# this is because 'b' has complete columns and the rows of 'a' are split
# locations of the remainders in b
b_rem_locs1 = torch.nonzero(rem_map[:, 1, 1] == 1, as_tuple=False)
a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False)
b_node_rem_s1 = b.larray[..., kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB]
# b_node_rem_s1 -> remainders for a in the
a_rem = torch.empty(
(*batch_shape, a.lshape[-2], a_rem_locs1.numel()),
dtype=b.dtype.torch_type(),
device=tdev,
)
# this if/elif/else loop is for the handling of
if comm.rank in b_rem_locs1:
# if b is split in dim1 and the rank has a remainder in this direction
r = b.larray[..., -1].unsqueeze(-1)
r_loc = index_map[comm.rank, 1, 1, 1] - index_map[comm.rank, 1, 1, 0] - 1
else:
r = None
r_loc = None
req = {}
a_lp_data = {}
for pr in range(comm.size):
# ibcast data on node first
if a.comm.rank == pr:
a_lp_data[pr] = a.larray.clone()
else:
a_lp_data[pr] = torch.zeros(
(*batch_shape, lshape_map[pr, 0, -2].item(), lshape_map[pr, 0, -1].item()),
dtype=a.dtype.torch_type(),
device=tdev,
)
# sending a to all nodes for b to operate with
req[pr] = comm.Ibcast(a_lp_data[pr], root=pr)
# receive the data from the last loop and do the calculation with that
if pr != 0:
# after receiving the last loop's bcast
req[pr - 1].Wait()
__mm_c_block_setter(
a_proc=pr - 1,
b_proc=comm.rank,
a_data=a_lp_data[pr - 1],
b_data=b.larray,
b_block_map=b_block_map,
a_block_map=a_block_map,
a_split=1,
b_split=1,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)
# check if there is a remainder on b in the previous node
# this loop is intended to get the remainders of b since it is the one being passed
if pr - 1 in a_rem_locs1:
# takes care of the remainders in b as well as dim0 of a
a_rem[..., pr - 1] = a_lp_data[pr - 1][..., -1]
# this loop is to take care of the remainders in dim1 of B
if b_rem_locs1.nelement() != 0 and r_loc is not None:
st = index_map[pr - 1, 0, 1, 0].item()
sp = index_map[pr - 1, 0, 1, 1].item()
c.larray[..., r_loc.item()] += (
a_lp_data[pr - 1] @ r[..., st:sp, :]
).squeeze(-1)
del a_lp_data[pr - 1]
# need to wait if its the last loop, also need to collect the remainders
if pr == b.comm.size - 1:
req[pr].Wait()
__mm_c_block_setter(
a_proc=pr,
b_proc=a.comm.rank,
a_data=a_lp_data[pr],
b_data=b.larray,
b_block_map=b_block_map,
a_block_map=a_block_map,
a_split=1,
b_split=1,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)
# check if there is a remainder on b on the last node (there shouldnt be)
if pr in a_rem_locs1:
# this is to save the data from B required by the remainders from dim1 of A
a_rem[..., pr] = a_lp_data[pr][..., -1]
# this loop is to take care of the remainders in the 0th dimension of A
if b_rem_locs1.nelement() != 0 and r_loc is not None:
st = index_map[pr, 0, 1, 0].item()
sp = index_map[pr, 0, 1, 1].item()
c.larray[..., r_loc.item()] += (a_lp_data[pr] @ r[..., st:sp, :]).squeeze(
-1
)
# set the final blocks on the last loop, then adjust for the the remainders which were collected in b_rem
if a_rem_locs1.numel():
c.larray[..., : b_node_rem_s1.shape[-1]] += a_rem @ b_node_rem_s1
del a_lp_data[pr]
# split la dims 10
elif a.split == ndim - 1 and b.split == ndim - 2:
# todo: this may create the full matrix on evey process, issue #360
# for this case, only a sum is needed at the end
a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False)
# locations of the remainders in b
b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False)
res = torch.zeros(
(*batch_shape, a.gshape[-2], b.gshape[-1]), dtype=c_type.torch_type(), device=tdev
)
for i in range(a.lshape[-1] // kB):
res += (
a.larray[..., :mB, i * kB : i * kB + kB]
@ b.larray[..., i * kB : i * kB + kB, :nB]
)
if a.comm.rank in a_rem_locs1 and b.comm.rank in b_rem_locs0 and kB > 1:
# these Nones are used to change the dims if the full process is not covered
res += a.larray[..., :, -1, None] @ b.larray[..., None, -1, :]
comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
split = a.split if b.gshape[-1] > 1 else ndim - 2
c = factories.array(res, split=split, device=dev, comm=comm)
if vector_flag: # squeeze only in the la dimensions
# it could be sensible to resplit/rebalance in case a single node gets the whole vector
split = c.split
if split is not None and split > batch_dim:
split = batch_dim
c_loc = c.larray
if c_loc.numel() == 0: # empty tensor cannot be squeezed
c_loc = torch.zeros((*batch_shape, 0), dtype=c_type.torch_type(), device=tdev)
else:
c_loc = c_loc.squeeze(batch_dim)
if c_loc.ndim >= batch_dim + 2:
c_loc = c_loc.squeeze(batch_dim + 1)
c = factories.array(c_loc, is_split=split, device=dev, comm=comm)
if gpu_int_flag:
c = og_type(c, device=dev)
return c
def _matmul(self, other):
try:
return matmul(self, other)
except TypeError:
return NotImplemented
DNDarray.__matmul__ = _matmul
DNDarray.__matmul__.__doc__ = matmul.__doc__
DNDarray.__rmatmul__ = lambda self, other: _matmul(other, self)
DNDarray.__rmatmul__.__doc__ = matmul.__doc__
[docs]
def matrix_norm(
x: DNDarray,
axis: Optional[Tuple[int, int]] = None,
keepdims: bool = False,
ord: Optional[Union[int, str]] = None,
) -> DNDarray:
"""
Computes the matrix norm of an array.
Parameters
----------
x : DNDarray
Input array
axis : tuple, optional
Both axes of the matrix. If `None` 'x' must be a matrix. Default: `None`
keepdims : bool, optional
Retains the reduced dimension when `True`. Default: `False`
ord : int, 'fro', 'nuc', optional
The matrix norm order to compute. If `None` the Frobenius norm (`'fro'`) is used. Default: `None`
See Also
--------
norm
Computes the vector or matrix norm of an array.
vector_norm
Computes the vector norm of an array.
Notes
-----
The following norms are supported:
===== ============================
ord norm for matrices
===== ============================
None Frobenius norm
'fro' Frobenius norm
'nuc' nuclear norm
inf max(sum(abs(x), axis=1))
-inf min(sum(abs(x), axis=1))
1 max(sum(abs(x), axis=0))
-1 min(sum(abs(x), axis=0))
===== ============================
The following matrix norms are currently **not** supported:
===== ============================
ord norm for matrices
===== ============================
2 largest singular value
-2 smallest singular value
===== ============================
Raises
------
TypeError
If axis is not a 2-tuple
ValueError
If an invalid matrix norm is given or 'x' is a vector.
Examples
--------
>>> ht.matrix_norm(ht.array([[1, 2], [3, 4]]))
DNDarray([[5.4772]], dtype=ht.float64, device=cpu:0, split=None)
>>> ht.matrix_norm(ht.array([[1, 2], [3, 4]]), keepdims=True, ord=-1)
DNDarray([[4.]], dtype=ht.float64, device=cpu:0, split=None)
"""
sanitation.sanitize_in(x)
if x.ndim < 2:
raise ValueError("Cannot compute a matrix norm of a vector.")
if axis is None:
if x.ndim > 2:
raise ValueError("Cannot infer axis on arrays with more than two dimensions.")
else:
axis = (0, 1)
if (not isinstance(axis, tuple)) or len(axis) != 2:
raise TypeError("'axis' must be a 2-tuple.")
row_axis, col_axis = axis
# dtype = types.promote_types(x.dtype, types.float32)
if ord == 1:
if col_axis > row_axis and not keepdims:
col_axis -= 1
return statistics.max(
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdims=keepdims,
)
elif ord == -1:
if col_axis > row_axis and not keepdims:
col_axis -= 1
return statistics.min(
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdims=keepdims,
)
elif ord == 2:
raise NotImplementedError("The largest singular value can't be computed yet.")
elif ord == -2:
raise NotImplementedError("The smallest singular value can't be computed yet.")
elif ord == constants.inf:
if row_axis > col_axis and not keepdims:
row_axis -= 1
return statistics.max(
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord == -constants.inf:
if row_axis > col_axis and not keepdims:
row_axis -= 1
return statistics.min(
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord in [None, "fro"]:
return exponential.sqrt(
arithmetics.sum((complex_math.conj(x) * x).real, axis=axis, keepdims=keepdims)
)
elif ord == "nuc":
raise NotImplementedError("The nuclear norm can't be computed yet.")
else:
raise ValueError("Invalid norm order for matrices.")
[docs]
def matrix_exp(A: DNDarray) -> DNDarray:
r"""
Computes the matrix exponential of a square matrix.
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
this function computes the **matrix exponential** of :math:`A \in \mathbb{K}^{n \times n}`, which is defined as
.. math::
\mathrm{matrix\_exp}(A) = \sum_{k=0}^\infty \frac{1}{k!}A^k \in \mathbb{K}^{n \times n}.
If the matrix :math:`A` has eigenvalues :math:`\lambda_i \in \mathbb{C}`,
the matrix :math:`\mathrm{matrix\_exp}(A)` has eigenvalues :math:`e^{\lambda_i} \in \mathbb{C}`.
Supports input of bfloat16, float, double, cfloat and cdouble dtypes.
Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
the output has the same batch dimensions.
.. note::
A may only be distributed in the batch dimensions.
.. seealso::
:func:`torch.linalg.matrix_exp` is called under the hood on the local data.
Args:
A (DNDarray): DNDarray of shape `(*, n, n)` where `*` is zero or more batch dimensions.
Example::
>>> A = ht.empty((2, 2, 2), split=0)
>>> A[0, :, :] = ht.eye((2, 2))
>>> A[1, :, :] = 2 * ht.eye((2, 2))
>>> ht.linalg.matrix_exp(A)
DNDarray([[[2.7183, 0.0000],
[0.0000, 2.7183]],
[[7.3891, 0.0000],
[0.0000, 7.3891]]], dtype=ht.float32, device=cpu:0, split=0)
"""
sanitation.sanitize_in(A)
if A.is_distributed() and A.split >= A.ndim - 2:
raise ValueError(
f"A of shape {A.shape} may only be distributed in batched dimensions but is distributed in {A.split}"
)
out = factories.empty_like(A)
out.larray[...] = torch.linalg.matrix_exp(A.larray)
return out
expm = matrix_exp # provide alias with name of scipy equivalent
"""Alias for :py:func:`matrix_exp`"""
[docs]
def norm(
x: DNDarray,
axis: Optional[Union[int, Tuple[int, int]]] = None,
keepdims: bool = False,
ord: Optional[Union[int, float, str]] = None,
) -> DNDarray:
"""
Return the vector or matrix norm of an array.
Parameters
----------
x : DNDarray
Input vector
axis : int, tuple, optional
Axes along which to compute the norm. If an integer, vector norm is used. If a 2-tuple, matrix norm is used.
If `None`, it is inferred from the dimension of the array. Default: `None`
keepdims : bool, optional
Retains the reduced dimension when `True`. Default: `False`
ord : int, float, inf, -inf, 'fro', 'nuc'
The norm order to compute. See Notes
See Also
--------
vector_norm
Computes the vector norm of an array.
matrix_norm
Computes the matrix norm of an array.
Notes
-----
The following norms are supported:
===== ============================ ==========================
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm L2-norm (Euclidean)
'fro' Frobenius norm --
'nuc' nuclear norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
1 max(sum(abs(x), axis=0)) L1-norm (Manhattan)
-1 min(sum(abs(x), axis=0)) 1./sum(1./abs(a))
2 -- L2-norm (Euclidean)
-2 -- 1./sqrt(sum(1./abs(a)**2))
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
The following matrix norms are currently **not** supported:
===== ============================
ord norm for matrices
===== ============================
2 largest singular value
-2 smallest singular value
===== ============================
Raises
------
ValueError
If 'axis' has more than 2 elements
Examples
--------
>>> from heat import linalg as LA
>>> a = ht.arange(9, dtype=ht.float) - 4
>>> a
DNDarray([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=ht.float32, device=cpu:0, split=None)
>>> b = a.reshape((3, 3))
>>> b
DNDarray([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a)
DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b)
DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b, ord="fro")
DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, float("inf"))
DNDarray([4.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b, ht.inf)
DNDarray([9.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, -ht.inf))
DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b, -ht.inf)
DNDarray([2.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, 1)
DNDarray([20.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b, 1)
DNDarray([7.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, -1)
DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(b, -1)
DNDarray([6.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, 2)
DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, -2)
DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, 3)
DNDarray([5.8480], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(a, -3)
DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
c = ht.array([[ 1, 2, 3],
[-1, 1, 4]])
>>> LA.norm(c, axis=0)
DNDarray([1.4142, 2.2361, 5.0000], dtype=ht.float64, device=cpu:0, split=None)
>>> LA.norm(c, axis=1)
DNDarray([3.7417, 4.2426], dtype=ht.float64, device=cpu:0, split=None)
>>> LA.norm(c, axis=1, ord=1)
DNDarray([6., 6.], dtype=ht.float64, device=cpu:0, split=None)
>>> m = ht.arange(8).reshape(2, 2, 2)
>>> LA.norm(m, axis=(1, 2))
DNDarray([ 3.7417, 11.2250], dtype=ht.float32, device=cpu:0, split=None)
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
(DNDarray(3.7417, dtype=ht.float32, device=cpu:0, split=None), DNDarray(11.2250, dtype=ht.float32, device=cpu:0, split=None))
"""
sanitation.sanitize_in(x)
ndim = x.ndim
if axis is None:
if ord is None or (ord == 2 and ndim == 1) or (ord == "fro" and ndim == 2):
x = x.flatten()
if types.issubdtype(x.dtype, types.complex):
sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag)
else:
sqnorm = dot(x, x)
ret = exponential.sqrt(sqnorm)
if keepdims:
ret = ret.reshape(ndim * [1])
return ret
elif ndim == 2:
return matrix_norm(x, axis, keepdims, ord)
else:
return vector_norm(x, axis, keepdims, ord)
if isinstance(axis, int) or len(axis) == 1:
return vector_norm(x, axis, keepdims, ord)
elif len(axis) == 2:
return matrix_norm(x, axis, keepdims, ord)
else:
raise ValueError("Improper number of dimensions to norm.")
DNDarray.norm: Callable[[DNDarray], float] = lambda self: norm(self)
DNDarray.norm.__doc__ = norm.__doc__
[docs]
def outer(
a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None, split: Optional[int] = None
) -> DNDarray:
"""
Compute the outer product of two 1-D DNDarrays: :math:`out(i, j) = a(i) \\times b(j)`.
Given two vectors, :math:`a = (a_0, a_1, ..., a_N)` and :math:`b = (b_0, b_1, ..., b_M)`, the outer product is:
.. math::
:nowrap:
\\begin{pmatrix}
a_0 \\cdot b_0 & a_0 \\cdot b_1 & . & . & a_0 \\cdot b_M \\\\
a_1 \\cdot b_0 & a_1 \\cdot b_1 & . & . & a_1 \\cdot b_M \\\\
. & . & . & . & . \\\\
a_N \\cdot b_0 & a_N \\cdot b_1 & . & . & a_N \\cdot b_M
\\end{pmatrix}
Parameters
----------
a : DNDarray
1-dimensional: :math:`N`
Will be flattened by default if more than 1-D.
b : DNDarray
1-dimensional: :math:`M`
Will be flattened by default if more than 1-D.
out : DNDarray, optional
2-dimensional: :math:`N \\times M`
A location where the result is stored
split : int, optional
Split dimension of the resulting DNDarray. Can be 0, 1, or None.
This is only relevant if the calculations are memory-distributed.
Default is ``split=0`` (see Notes).
Notes
-----
Parallel implementation of outer product, assumes arrays are dense.
In the classical (dense) case, one of the two arrays needs to be communicated around the processes in
a ring.
* Sending ``b`` around in a ring results in ``outer`` being split along the rows (``outer.split = 0``).\n
* Sending ``a`` around in a ring results in ``outer`` being split along the columns (``outer.split = 1``).\n
So, if specified, ``split`` defines which ``DNDarray`` stays put and which one is passed around.
If ``split`` is ``None`` or unspecified, the result will be distributed along axis ``0``, i.e. by default ``b`` is
passed around, ``a`` stays put.
Examples
--------
>>> a = ht.arange(4)
>>> b = ht.arange(3)
>>> ht.outer(a, b).larray
(3 processes)
[0/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
[1/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
[2/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
>>> a = ht.arange(4, split=0)
>>> b = ht.arange(3, split=0)
>>> ht.outer(a, b).larray
[0/2] tensor([[0, 0, 0],
[0, 1, 2]], dtype=torch.int32)
[1/2] tensor([[0, 2, 4]], dtype=torch.int32)
[2/2] tensor([[0, 3, 6]], dtype=torch.int32)
>>> ht.outer(a, b, split=1).larray
[0/2] tensor([[0],
[0],
[0],
[0]], dtype=torch.int32)
[1/2] tensor([[0],
[1],
[2],
[3]], dtype=torch.int32)
[2/2] tensor([[0],
[2],
[4],
[6]], dtype=torch.int32)
>>> a = ht.arange(5, dtype=ht.float32, split=0)
>>> b = ht.arange(4, dtype=ht.float64, split=0)
>>> out = ht.empty((5,4), dtype=ht.float64, split=1)
>>> ht.outer(a, b, split=1, out=out)
>>> out.larray
[0/2] tensor([[0., 0.],
[0., 1.],
[0., 2.],
[0., 3.],
[0., 4.]], dtype=torch.float64)
[1/2] tensor([[0.],
[2.],
[4.],
[6.],
[8.]], dtype=torch.float64)
[2/2] tensor([[ 0.],
[ 3.],
[ 6.],
[ 9.],
[12.]], dtype=torch.float64)
"""
# sanitize input
devices = []
for array in [a, b]:
sanitation.sanitize_in(array)
devices.append(array.device)
if devices.count(devices[0]) == 2:
device = devices[0]
else:
raise RuntimeError(
f"input arrays on different devices: input 0 on {devices[0]}, input 1 on {devices[1]}"
)
# sanitize dimensions
# TODO implement is_1D in sanitation module #468
if a.ndim > 1:
a = manipulations.flatten(a)
if b.ndim > 1:
b = manipulations.flatten(b)
if a.ndim == 0 or b.ndim == 0:
raise RuntimeError(f"a, b must be 1-D DNDarrays, but were {a.ndim}-D and {b.ndim}-D")
outer_gshape = (a.gshape[0], b.gshape[0])
t_a = a.larray
t_b = b.larray
t_outer_dtype = torch.promote_types(t_a.dtype, t_b.dtype)
t_a, t_b = t_a.type(t_outer_dtype), t_b.type(t_outer_dtype)
outer_dtype = types.canonical_heat_type(t_outer_dtype)
if out is not None:
sanitation.sanitize_out(out, outer_gshape, split, device)
t_out_dtype = out.larray.dtype
# distributed outer product, dense arrays (TODO: sparse, #384)
if a.comm.is_distributed() and split is not None or a.split is not None or b.split is not None:
# MPI coordinates
rank = a.comm.rank
size = a.comm.size
t_outer_slice = 2 * [slice(None, None, None)]
if a.split is None:
a.resplit_(axis=0)
t_a = a.larray.type(t_outer_dtype)
if b.split is None:
b.resplit_(axis=0)
t_b = b.larray.type(t_outer_dtype)
if split is None:
# Split semantics: default out.split = a.split
split = a.split
if out is not None and out.split is None:
out.resplit_(axis=split)
# calculate local slice of outer product
if split == 0:
lshape_map = b.create_lshape_map()
t_outer_shape = (a.lshape[0], b.gshape[0])
_, _, local_slice = b.comm.chunk(b.gshape, b.split)
t_outer_slice[1] = local_slice[0]
elif split == 1:
lshape_map = a.create_lshape_map()
t_outer_shape = (a.gshape[0], b.lshape[0])
_, _, local_slice = a.comm.chunk(a.gshape, a.split)
t_outer_slice[0] = local_slice[0]
t_outer = torch.zeros(t_outer_shape, dtype=t_outer_dtype, device=t_a.device)
if lshape_map[rank] != 0:
t_outer[tuple(t_outer_slice)] = torch.einsum("i,j->ij", t_a, t_b)
# Ring: fill in missing slices of outer product
# allocate memory for traveling data
if split == 0:
t_b_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_a.device)
elif split == 1:
t_a_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_b.device)
for p in range(size - 1):
# prepare for sending
dest_rank = rank + 1 if rank != size - 1 else 0
# prepare for receiving
origin_rank = rank - 1 if rank != 0 else size - 1
actual_origin = origin_rank - p
if origin_rank < p:
actual_origin += size
# blocking send and recv
if split == 0:
b.comm.Send(t_b, dest_rank)
b.comm.Recv(t_b_run, origin_rank)
# buffer from actual_origin could be smaller than allocated buffer
t_b = t_b_run[: lshape_map[actual_origin]]
_, _, remote_slice = b.comm.chunk(
b.gshape, b.split, rank=actual_origin, w_size=size
)
t_outer_slice[1] = remote_slice[0]
elif split == 1:
a.comm.Send(t_a, dest_rank)
a.comm.Recv(t_a_run, origin_rank)
# buffer from actual_origin could be smaller than allocated buffer
t_a = t_a_run[: lshape_map[actual_origin]]
_, _, remote_slice = a.comm.chunk(
a.gshape, a.split, rank=actual_origin, w_size=size
)
t_outer_slice[0] = remote_slice[0]
t_outer[tuple(t_outer_slice)] = torch.einsum("i,j->ij", t_a, t_b)
else:
# outer product, all local
t_outer = torch.einsum("i,j->ij", t_a, t_b)
split = None
outer = DNDarray(
t_outer,
gshape=outer_gshape,
dtype=outer_dtype,
split=split,
device=a.device,
comm=a.comm,
balanced=True,
)
if out is not None:
out.larray = outer.larray.type(t_out_dtype)
return out
return outer
[docs]
def projection(a: DNDarray, b: DNDarray) -> DNDarray:
"""
Projection of vector ``a`` onto vector ``b``
Parameters
----------
a : DNDarray
The vector to be projected. Must be a 1D ``DNDarray``
b : DNDarray
The vector to project onto. Must be a 1D ``DNDarray``
"""
if not isinstance(a, DNDarray) or not isinstance(b, DNDarray):
raise TypeError(f"a, b must be of type ht.DNDarray, but were {type(a)}, {type(b)}")
if len(a.shape) != 1 or len(b.shape) != 1:
raise RuntimeError(
f"a, b must be vectors of length 1, but were {len(a.shape)}, {len(b.shape)}"
)
return (dot(a, b) / dot(b, b)) * b
[docs]
def trace(
a: DNDarray,
offset: Optional[int] = 0,
axis1: Optional[int] = 0,
axis2: Optional[int] = 1,
dtype: Optional[types.datatype] = None,
out: Optional[DNDarray] = None,
) -> Union[DNDarray, float]:
"""
Return the sum along diagonals of the array
If `a` is 2D, the sum along its diagonal with the given offset is returned, i.e. the sum of
elements a[i, i+offset] for all i.
If `a` has more than two dimensions, then the axes specified by `axis1` and `axis2` are used
to determine the 2D-sub-DNDarrays whose traces are returned.
The shape of the resulting array is the same as that of `a` with `axis1` and `axis2` removed.
Parameters
----------
a : array_like
Input array, from which the diagonals are taken
offset : int, optional
Offsets of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
axis1: int, optional
Axis to be used as the first axis of the 2D-sub-arrays from which the diagonals
should be taken. Default is the first axis of `a`
axis2 : int, optional
Axis to be used as the second axis of the 2D-sub-arrays from which the diagonals
should be taken. Default is the second two axis of `a`
dtype : dtype, optional
Determines the data-type of the returned array and of the accumulator where the elements are
summed. If `dtype` has value None than the dtype is the same as that of `a`
out: ht.DNDarray, optional
Array into which the output is placed. Its type is preserved and it must be of the right shape
to hold the output
Only applicable if `a` has more than 2 dimensions, thus the result is not a scalar.
If distributed, its split axis might change eventually.
Returns
-------
sum_along_diagonals : number (of defined dtype) or ht.DNDarray
If `a` is 2D, the sum along the diagonal is returned as a scalar
If `a` has more than 2 dimensions, then a DNDarray of sums along diagonals is returned
Examples
--------
2D-case
>>> x = ht.arange(24).reshape((4, 6))
>>> x
DNDarray([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]], dtype=ht.int32, device=cpu:0, split=None)
>>> ht.trace(x)
42
>>> ht.trace(x, 1)
46
>>> ht.trace(x, -2)
31
> 2D-case
>>> x = x.reshape((2, 3, 4))
>>> x
DNDarray([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=ht.int32, device=cpu:0, split=None)
>>> ht.trace(x)
DNDarray([16, 18, 20, 22], dtype=ht.int32, device=cpu:0, split=None)
>>> ht.trace(x, 1)
DNDarray([24, 26, 28, 30], dtype=ht.int32, device=cpu:0, split=None)
>>> ht.trace(x, axis1=0, axis2=2)
DNDarray([13, 21, 29], dtype=ht.int32, device=cpu:0, split=None)
"""
# ----------------------------------------------------------------------------
# SANITATION
# ----------------------------------------------------------------------------
if not isinstance(a, (DNDarray, torch.Tensor, np.ndarray, list, tuple)):
raise TypeError(
f"`a` must be a DNDarray, torch.Tensor, np.ndarray, list or tuple, is {type(a)}"
)
# cast input `a` to DNDarray
elif not isinstance(a, DNDarray):
a = factories.array(a)
# assure correct dimensionality of input
if len(a.lshape) < 2:
raise ValueError(f"`a` must contain at least 2 dimensions, not {len(a.lshape)}")
# sanitize axis1, axis2
if not isinstance(axis1, int):
raise TypeError(f"`axis1` must be integer, not {type(axis1)}")
if not isinstance(axis2, int):
raise TypeError(f"`axis2` must be integer, not {type(axis2)}")
# translate negative to positive indexing (trace axes)
if axis1 < 0:
axis1 = axis1 % a.ndim
if axis2 < 0:
axis2 = axis2 % a.ndim
if axis1 == axis2:
raise ValueError(f"axis1 ({axis1}) and axis2 ({axis2}) cannot be the same.")
if axis1 >= a.ndim:
raise ValueError(f"`axis1` ({axis1}) out of bounds for {a.ndim}-dimensional array.")
if axis2 >= a.ndim:
raise ValueError(f"`axis2` ({axis2}) out of bounds for {a.ndim}-dimensional array.")
# sanitize offset
if not isinstance(offset, int):
raise TypeError(f"`offset` must be an integer, not {type(offset)}")
# sanitize dtype
try:
if dtype is None:
dtype = a.dtype
else:
dtype = types.canonical_heat_type(dtype)
except TypeError: # type cannot be converted to ht.type
raise ValueError(f"`dtype` must be a datatype or None, not {type(dtype)}")
# sanitize out
if out is not None:
if not isinstance(out, DNDarray):
raise TypeError(f"`out` must be a ht.DNDarray or None not {type(out)}")
elif a.ndim == 2:
raise ValueError(
"`out` is not applicable if result is a scalar / input `a` is 2-dimensional"
)
# ----------------------------------------------------------------------------
# ALGORITHM
# ----------------------------------------------------------------------------
# ---------------------------------------------
# CASE 2D input (ignore axis1, axis) => scalar
# ---------------------------------------------
if a.ndim == 2:
# CASE 1.1: offset results into an empty array
if offset <= -a.gshape[0] or offset >= a.gshape[1]:
sum_along_diagonals_t = torch.tensor(
0, dtype=dtype.torch_type(), device=a.device.torch_device
)
# CASE 1.2: non-zero array, call torch.trace on concerned sub-DNDarray
else:
# determine the additional offset created by distribution of `a`
a_sub = a
if a.is_distributed():
offset_split, _, _ = a.comm.chunk(a.gshape, a.split)
if a.split == 0:
offset += offset_split
# a.split == 1
else:
offset -= offset_split
# Calculate resulting/concerned sub-array `a_sub`
if offset > 0:
offset = min(offset, a_sub.lshape[1])
a_sub = factories.array(
a_sub.larray[:, offset:], device=a_sub.device, comm=a_sub.comm
)
elif offset < 0:
offset = min(-offset, a_sub.lshape[0])
a_sub = factories.array(
a_sub.larray[offset:, :], device=a_sub.device, comm=a_sub.comm
)
# calculate trace /partial sum on that sub-array
if 0 not in a_sub.lshape:
sum_along_diagonals_t = torch.trace(a_sub.larray)
# make sure result is of correct dtype
sum_along_diagonals_t = sum_along_diagonals_t.type(dtype.torch_type())
# empty array => result = 0
else:
sum_along_diagonals_t = torch.tensor(
0, dtype=dtype.torch_type(), device=a_sub.device.torch_device
)
# sum up all partial sums
if a.is_distributed():
a.comm.Allreduce(MPI.IN_PLACE, sum_along_diagonals_t, MPI.SUM)
# convert resulting 0-d tensor to (python) scalar
return sum_along_diagonals_t.item()
# -------------------------------
# CASE > 2D => DNDArray
# -------------------------------
# sanitize axis1, axis2 (make sure axis1 < axis2)
if axis1 > axis2:
axis1, axis2 = axis2, axis1
# ----------------------------------
# CASE split axis NOT IN trace axes
# ----------------------------------
# compute each diagonal sum
if not (a.is_distributed() and a.split in (axis1, axis2)):
# extract diagonals
diag_t = torch.diagonal(a.larray, offset=offset, dim1=axis1, dim2=axis2)
# sum them up along the last axis (and convert to given dtype)
last_axis = diag_t.ndim - 1
sum_along_diagonals_t = torch.sum(diag_t, last_axis, dtype=dtype.torch_type())
# -----------------------------
# CASE split axis IN trace axes
# -----------------------------
else:
# combination that would NOT result into array of zeros
if -offset < a.gshape[axis1] or offset < a.gshape[axis2]:
# adapt the offset to distribution
# (to result into required diagonal elements on each process)
offset_split, _, _ = a.comm.chunk(a.gshape, a.split)
if a.split == axis1:
offset += offset_split
else: # a.split == axis2
offset -= offset_split
diag_t = torch.diagonal(a.larray, offset=offset, dim1=axis1, dim2=axis2)
# empty diagonal => create an array of zeros for following summation
if 0 in diag_t.shape:
res_shape = [1 if i == 0 else i for i in diag_t.shape]
diag_t = torch.zeros(res_shape, device=a.device.torch_device)
# create recvbuffer (with correct resulting shape)
sum_along_diagonals_t = torch.clone(diag_t)
res_shape = list(sum_along_diagonals_t.shape)
del res_shape[-1] # as summed up along the last axis
sum_along_diagonals_t = torch.reshape(sum_along_diagonals_t, res_shape)
# Sum up all partial sums (and gather them)
# in out
if out is not None:
result_array = out
# in a
else:
result_array = a
result_array.comm.Allreduce(MPI.IN_PLACE, sum_along_diagonals_t, MPI.SUM)
if result_array.split is None:
split_axis = None
else:
last_axis = sum_along_diagonals_t.ndim - 1
split_axis = result_array.split if result_array.split <= last_axis else last_axis
sum_along_diagonals = factories.array(
sum_along_diagonals_t,
dtype=dtype,
split=split_axis,
comm=result_array.comm,
device=result_array.device,
)
if out is not None:
sanitation.sanitize_out(out, tuple(res_shape), out.split, out.device)
out.larray = sum_along_diagonals.larray
return sum_along_diagonals
if a.is_distributed():
# (...and a.split not in (axis1, axis2))
gather_axis = a.split if a.split < axis2 else a.split - 2
# check if gather_axis is in range of result
if gather_axis >= sum_along_diagonals_t.ndim:
gather_axis = sum_along_diagonals_t.ndim - 1
# Stack all partial results back together along the correct axis
sum_along_diagonals = factories.array(
sum_along_diagonals_t, dtype=dtype, is_split=gather_axis, comm=a.comm, device=a.device
)
# input not distributed
else:
# check if split axis is in range of result
if a.split is not None and a.split >= sum_along_diagonals_t.ndim:
gather_axis = sum_along_diagonals_t.ndim - 1
else:
gather_axis = a.split
# convert torch result back to DNDarray
sum_along_diagonals = factories.array(
sum_along_diagonals_t, dtype=dtype, split=gather_axis, comm=a.comm, device=a.device
)
if out is not None:
# resplit to guarantee correct results
if out.split != gather_axis:
warnings.warn(
f"Split axis of `out` will be changed from {out.split} to {gather_axis} to "
f"guarantee correct results."
)
out.resplit_(gather_axis)
# sanitize out
output_gshape = list(a.gshape)
del output_gshape[axis1], output_gshape[axis2 - 1]
sanitation.sanitize_out(out, tuple(output_gshape), gather_axis, out.device)
# store result
out.larray = sum_along_diagonals_t
return out
return sum_along_diagonals
# inline function
DNDarray.trace: Callable[
[
DNDarray,
Optional[int],
Optional[int],
Optional[int],
Optional[types.datatype],
Optional[DNDarray],
],
Union[DNDarray, float],
] = lambda self, offset=0, axis1=0, axis2=1, dtype=None, out=None: trace(
self, offset, axis1, axis2, dtype, out
)
DNDarray.trace.__doc__ = trace.__doc__
@torch.jit.script
def __mm_c_block_setter(
b_proc: int,
a_proc: int,
a_data: torch.Tensor,
b_data: torch.Tensor,
b_block_map: torch.Tensor,
a_block_map: torch.Tensor,
b_split: int,
a_split: int,
mB: int,
kB: int,
nB: int,
c: torch.Tensor,
) -> None:
"""
Helper function for multiplying elements of A and B (see :func:'matmul <matmul>') and putting the results into the
correct place in C.
Parameters
----------
b_proc : int
process with the data for the data for element b
a_proc : int
process with the data for the data for element a
a_data : torch.Tensor
data from A
b_data : torch.Tensor
data from B
b_block_map : torch.Tensor
block map for B
a_block_map : torch.Tensor
block map for A
b_split : int
split of B (0 or 1)
a_split : int
split of A (0 or 1)
mB : int
block size of m
kB : int
block size of K
nB : int
block size of n
c : torch.Tensor
the local data for C
"""
# # (int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int, int, torch.Tensor) -> None
shp_b = b_block_map.shape
offset_a = b_proc * shp_b[1] if b_proc != 0 else 0
shp_a = a_block_map.shape
offset_b = a_proc * shp_a[2] if a_proc != 0 else 0
# offsets are the number of blocks in the multiplication direction on previous nodes
for bl_1_a in (
torch.arange(offset_a, offset_a + shp_b[1], dtype=torch.long, device=c.device)
if b_split == 0
else torch.arange(a_block_map[a_proc].shape[0], dtype=torch.long, device=c.device)
):
# offset is the number of blocks on the previous node in the direction of multiplication
for bl_0_a in torch.arange(
a_block_map[a_proc].shape[0], dtype=torch.long, device=c.device
): # dim0
for bl_1_b in torch.arange(
b_block_map[b_proc].shape[1], dtype=torch.long, device=c.device
):
for bl_0_b in (
torch.arange(offset_b, offset_b + shp_a[1], dtype=torch.long, device=c.device)
if a_split == 1
else torch.arange(
b_block_map[b_proc].shape[0], dtype=torch.long, device=c.device
)
):
# this offset is the same as before but for b
a_start1 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 1].item())
a_start0 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 0].item())
a_block = a_data[..., a_start0 : a_start0 + mB, a_start1 : a_start1 + kB]
b_start0 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 0].item())
b_start1 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 1].item())
b_block = b_data[..., b_start0 : b_start0 + kB, b_start1 : b_start1 + nB]
c_start0 = a_start0
c_start1 = b_start1
c[..., c_start0 : c_start0 + mB, c_start1 : c_start1 + nB] += a_block @ b_block
[docs]
def transpose(a: DNDarray, axes: Optional[List[int]] = None) -> DNDarray:
"""
Permute the dimensions of an array.
Parameters
----------
a : DNDarray
Input array.
axes : None or List[int,...], optional
By default, reverse the dimensions, otherwise permute the axes according to the values given.
"""
# type check the input tensor
sanitation.sanitize_in(a)
# set default value for axes permutations
dimensions = len(a.shape)
if axes is None:
axes = tuple(reversed(range(dimensions)))
# if given, sanitize the input
else:
try:
# convert to a list to allow index access
axes = list(axes)
except TypeError:
raise ValueError("axes must be an iterable containing ints")
if len(axes) != dimensions:
raise ValueError("axes do not match tensor shape")
for index, axis in enumerate(axes):
if not isinstance(axis, int):
raise TypeError(f"axis must be an integer, but was {type(axis)}")
elif axis < 0:
axes[index] = axis + dimensions
# infer the new split axis, it is the position of the split axis within the new axes permutation
try:
transposed_split = axes.index(a.split) if a.split is not None else None
except ValueError:
raise ValueError("axes do not match tensor shape")
# try to rearrange the tensor and return a new transposed variant
try:
transposed_data = a.larray.permute(*axes)
transposed_shape = tuple(a.shape[axis] for axis in axes)
return DNDarray(
transposed_data,
transposed_shape,
a.dtype,
transposed_split,
a.device,
a.comm,
a.balanced,
)
# if not possible re- raise any torch exception as ValueError
except (RuntimeError, IndexError) as exception:
raise ValueError(str(exception))
DNDarray.transpose: Callable[[DNDarray, List[int]], DNDarray] = lambda self, axes=None: transpose(
self, axes
)
DNDarray.transpose.__doc__ = transpose.__doc__
DNDarray.T = property(transpose)
# statically allocated index slices for non-iterable dimensions in triangular operations
__index_base = (slice(None), slice(None))
def __tri_op(m: DNDarray, k: int, op: Callable) -> DNDarray:
"""
Generic implementation of triangle operations on a ``DNDarray``. It takes care of input sanitation and non-standard
broadcast behavior of the 2D triangle-operators.
Parameters
----------
m : DNDarray
Input array for which to compute the triangle operator.
k : int, optional
Diagonal above which to apply the triangle operator, ``k<0`` is below and ``k>0`` is above.
op : callable
Implementation of the triangle operator.
Raises
------
TypeError
If the input is not a tensor or the diagonal offset cannot be converted to an integral value.
"""
sanitation.sanitize_in(m)
try:
k = int(k)
except ValueError:
raise TypeError(f"Expected k to be integral, but was {type(k)}")
# chunk the global shape of the tensor to obtain the offset compared to the other ranks
offset, _, _ = m.comm.chunk(m.shape, m.split)
dimensions = len(m.shape)
# manually repeat the input for vectors
if dimensions == 1:
triangle = m.larray.expand(m.shape[0], -1)
if torch.numel(triangle > 0):
triangle = op(triangle, k - offset)
return DNDarray(
triangle,
(m.shape[0], m.shape[0]),
m.dtype,
None if m.split is None else 1,
m.device,
m.comm,
m.balanced,
)
original = m.larray
output = original.clone()
# modify k to account for tensor splits
if m.split is not None:
if m.split + 1 == dimensions - 1:
k += offset
elif m.split == dimensions - 1:
k -= offset
# in case of two dimensions we can just forward the call to the callable
if dimensions == 2:
if torch.numel(original) > 0:
op(original, k, out=output)
# more than two dimensions: iterate over all but the last two to realize 2D broadcasting
else:
ranges = [range(elements) for elements in m.lshape[:-2]]
for partial_index in itertools.product(*ranges):
index = partial_index + __index_base
op(original[index], k, out=output[index])
return DNDarray(output, m.shape, m.dtype, m.split, m.device, m.comm, m.balanced)
[docs]
def tril(m: DNDarray, k: int = 0) -> DNDarray:
"""
Returns the lower triangular part of the ``DNDarray``.
The lower triangular part of the array is defined as the elements on and below the diagonal, the other elements of
the result array are set to 0.
The argument ``k`` controls which diagonal to consider. If ``k=0``, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main diagonal, and similarly a negative
value excludes just as many diagonals below the main diagonal.
Parameters
----------
m : DNDarray
Input array for which to compute the lower triangle.
k : int, optional
Diagonal above which to zero elements. ``k=0`` (default) is the main diagonal, ``k<0`` is below and ``k>0`` is above.
"""
return __tri_op(m, k, torch.tril)
DNDarray.tril: Callable[[DNDarray, int], DNDarray] = lambda self, k=0: tril(self, k)
DNDarray.tril.__doc__ = tril.__doc__
[docs]
def triu(m: DNDarray, k: int = 0) -> DNDarray:
"""
Returns the upper triangular part of the ``DNDarray``.
The upper triangular part of the array is defined as the elements on and below the diagonal, the other elements of the result array are set to 0.
The argument ``k`` controls which diagonal to consider. If ``k=0``, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main diagonal, and similarly a negative
value excludes just as many diagonals below the main diagonal.
Parameters
----------
m : DNDarray
Input array for which to compute the upper triangle.
k : int, optional
Diagonal above which to zero elements. ``k=0`` (default) is the main diagonal, ``k<0`` is below and ``k>0`` is above.
"""
return __tri_op(m, k, torch.triu)
DNDarray.triu: Callable[[DNDarray, int], DNDarray] = lambda self, k=0: triu(self, k)
DNDarray.triu.__doc__ = triu.__doc__
[docs]
def vdot(x1: DNDarray, x2: DNDarray) -> DNDarray:
"""
Computes the dot product of two vectors. Higher-dimensional arrays will be flattened.
Parameters
----------
x1 : DNDarray
first input array. If it's complex, it's complex conjugate will be used.
x2 : DNDarray
second input array.
Raises
------
ValueError
If the number of elements is inconsistent.
See Also
--------
dot
Return the dot product without using the complex conjugate.
Examples
--------
>>> a = ht.array([1 + 1j, 2 + 2j])
>>> b = ht.array([1 + 2j, 3 + 4j])
>>> ht.vdot(a, b)
DNDarray([(17+3j)], dtype=ht.complex64, device=cpu:0, split=None)
>>> ht.vdot(b, a)
DNDarray([(17-3j)], dtype=ht.complex64, device=cpu:0, split=None)
"""
x1 = manipulations.flatten(x1)
x2 = manipulations.flatten(x2)
return arithmetics.sum(arithmetics.multiply(complex_math.conjugate(x1), x2))
[docs]
def vecdot(
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdims: Optional[bool] = None
) -> DNDarray:
"""
Computes the (vector) dot product of two DNDarrays.
Parameters
----------
x1 : DNDarray
first input array.
x2 : DNDarray
second input array. Must be compatible with x1.
axis : int, optional
axis over which to compute the dot product. The last dimension is used if 'None'.
keepdims : bool, optional
If this is set to 'True', the axes which are reduced are left in the result as dimensions with size one.
See Also
--------
dot
NumPy-like dot function.
Examples
--------
>>> ht.vecdot(ht.full((3, 3, 3), 3), ht.ones((3, 3)), axis=0)
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)
"""
m = arithmetics.mul(x1, x2)
if axis is None:
axis = m.ndim - 1
return arithmetics.sum(m, axis=axis, keepdims=keepdims)
[docs]
def vector_norm(
x: DNDarray,
axis: Optional[Union[int, Tuple[int]]] = None,
keepdims=False,
ord: Optional[Union[int, float]] = None,
) -> DNDarray:
"""
Computes the vector norm of an array.
Parameters
----------
x : DNDarray
Input array
axis : int, tuple, optional
Axis along which to compute the vector norm. If `None` 'x' must be a vector. Default: `None`
keepdims : bool, optional
Retains the reduced dimension when `True`. Default: `False`
ord : int, float, optional
The norm order to compute. If `None` the euclidean norm (`2`) is used. Default: `None`
See Also
--------
norm
Computes the vector norm or matrix norm of an array.
matrix_norm
Computes the matrix norm of an array.
Notes
-----
The following norms are suported:
===== ==========================
ord norm for vectors
===== ==========================
None L2-norm (Euclidean)
inf max(abs(x))
-inf min(abs(x))
0 sum(x != 0)
1 L1-norm (Manhattan)
-1 1./sum(1./abs(a))
2 L2-norm (Euclidean)
-2 1./sqrt(sum(1./abs(a)**2))
other sum(abs(x)**ord)**(1./ord)
===== ==========================
Raises
------
TypeError
If axis is not an integer or a 1-tuple
ValueError
If an invalid vector norm is given.
Examples
--------
>>> ht.vector_norm(ht.array([1, 2, 3, 4]))
DNDarray([5.4772], dtype=ht.float64, device=cpu:0, split=None)
>>> ht.vector_norm(ht.array([[1, 2], [3, 4]]), axis=0, ord=1)
DNDarray([[4., 6.]], dtype=ht.float64, device=cpu:0, split=None)
"""
sanitation.sanitize_in(x)
if axis is None:
pass
elif isinstance(axis, tuple):
if len(axis) > 1:
raise TypeError("'axis' must be an integer or 1-tuple for vectors.")
else:
try:
axis = int(axis)
except Exception:
raise TypeError("'axis' must be an integer or 1-tuple for vectors.")
if ord == constants.INF:
return statistics.max(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord == -constants.INF:
return statistics.min(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return arithmetics.sum(x != 0, axis=axis, keepdims=keepdims).astype(types.float)
elif ord == 1:
return arithmetics.sum(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord is None or ord == 2:
s = (complex_math.conj(x) * x).real
return exponential.sqrt(arithmetics.sum(s, axis=axis, keepdims=keepdims))
elif isinstance(ord, str):
raise ValueError(f"Norm order {ord} is invalid for vectors")
else:
ret = arithmetics.pow(rounding.abs(x), ord)
ret = arithmetics.sum(ret, axis=axis, keepdims=keepdims)
ret = arithmetics.pow(ret, 1.0 / ord)
return ret