Source code for heat.core.manipulations

"""
Manipulation operations for (potentially distributed) `DNDarray`s.
"""

from __future__ import annotations

import numpy as np
import torch
import warnings

from typing import Any, Iterable, Type, List, Callable, Union, Tuple, Sequence, Optional

from .communication import MPI, Communication
from .dndarray import DNDarray

from . import arithmetics
from . import factories
from . import indexing
from . import linalg
from . import sanitation
from . import stride_tricks
from . import tiling
from . import types
from . import _operations

# Names exported by ``from heat.core.manipulations import *``; kept in alphabetical order.
__all__ = [
    "balance",
    "broadcast_arrays",
    "broadcast_to",
    "collect",
    "column_stack",
    "concatenate",
    "diag",
    "diagonal",
    "dsplit",
    "expand_dims",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "moveaxis",
    "pad",
    "ravel",
    "redistribute",
    "repeat",
    "reshape",
    "resplit",
    "roll",
    "rot90",
    "row_stack",
    "shape",
    "sort",
    "split",
    "squeeze",
    "stack",
    "swapaxes",
    "tile",
    "topk",
    "unfold",
    "unique",
    "vsplit",
    "vstack",
]


def balance(array: DNDarray, copy=False) -> DNDarray:
    """
    Out of place balance function. More information on the meaning of balance can be found in
    :func:`DNDarray.balance_() <heat.core.dndarray.DNDarray.balance_()>`.

    Parameters
    ----------
    array : DNDarray
        The DNDarray to be balanced.
    copy : bool, optional
        If ``True``, ``array`` is copied first and the balanced copy is returned.
        If ``False`` (default), ``array`` itself is balanced in place and returned.
    """
    if copy:
        target = array.copy()
    else:
        target = array
    target.balance_()
    return target
# expose the out-of-place ``balance`` function as a ``DNDarray`` method and reuse its docstring
DNDarray.balance = lambda self, copy=False: balance(self, copy)
DNDarray.balance.__doc__ = balance.__doc__
def broadcast_arrays(*arrays: DNDarray) -> List[DNDarray]:
    """
    Broadcasts one or more arrays against one another. Returns the broadcasted arrays, distributed
    along the split dimension of the first array in the list. If the first array is not distributed,
    the output will not be distributed.

    Parameters
    ----------
    arrays : DNDarray
        An arbitrary number of to-be broadcasted ``DNDarray``s.

    Raises
    ------
    ValueError
        If the distributions cannot be matched to the first array, or if the global shapes cannot
        be broadcast to a single shape.

    Notes
    -----
    Broadcasted arrays are a view of the original arrays if possible, otherwise a copy is made.
    When zero or one array is passed, the input tuple is returned unchanged.

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.ones((100, 10), split=0)
    >>> b = ht.ones((10,), split=None)
    >>> c = ht.ones((1, 10), split=1)
    >>> d, e, f = ht.broadcast_arrays(a, b, c)
    >>> d.shape
    (100, 10)
    >>> e.shape
    (100, 10)
    >>> f.shape
    (100, 10)
    >>> d.split
    0
    >>> e.split
    0
    >>> f.split
    0
    """
    # nothing to broadcast against
    if len(arrays) <= 1:
        return arrays

    # align all inputs with the distribution of the first one
    try:
        arrays = sanitation.sanitize_distribution(*arrays, target=arrays[0])
        reference = arrays[0]
        output_split = reference.split
        output_comm = reference.comm
        output_balanced = reference.balanced
    except NotImplementedError as e:
        raise ValueError(e)

    # global shapes and process-local torch tensors of every input
    gshapes = [array.gshape for array in arrays]
    t_arrays = tuple(array.larray for array in arrays)

    # broadcast the global shapes
    try:
        output_shape = tuple(torch.broadcast_shapes(*gshapes))
    except RuntimeError:
        raise ValueError(
            f"Shape mismatch: objects cannot be broadcast to a single shape. Original shapes: {gshapes}"
        )

    # broadcast the local torch tensors: this is a view of the original data
    broadcasted = torch.broadcast_tensors(*t_arrays)

    return [
        DNDarray(
            local,
            gshape=output_shape,
            dtype=array.dtype,
            split=output_split,
            device=array.device,
            comm=output_comm,
            balanced=output_balanced,
        )
        for local, array in zip(broadcasted, arrays)
    ]
def broadcast_to(x: DNDarray, shape: Tuple[int, ...]) -> DNDarray:
    """
    Broadcasts an array to a specified shape. Returns a view of ``x`` if ``x`` is not distributed,
    otherwise it returns a broadcasted, distributed, load-balanced copy of ``x``.

    Parameters
    ----------
    x : DNDarray
        `DNDarray` to broadcast.
    shape : Tuple[int, ...]
        Array shape. Must be compatible with ``x``.

    Raises
    ------
    ValueError
        If the array is not compatible with the new shape according to PyTorch's broadcasting rules.

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.arange(100, split=0)
    >>> b = ht.broadcast_to(a, (10, 100))
    >>> b.shape
    (10, 100)
    >>> b.split
    1
    >>> c = ht.broadcast_to(a, (100, 10))
    ValueError: Shape mismatch: object cannot be broadcast to the given shape. Original shape: (100,), target shape: (100, 10)
    """
    sanitation.sanitize_in(x)

    # figure out the output split axis via dndarray.__torch_proxy__ and named tensors functionality
    proxy = x.__torch_proxy__()
    has_split = x.split is not None
    if has_split:
        # tag the split dimension by name so it can be located again after broadcasting
        dim_names = [None] * x.ndim
        dim_names[x.split] = "split"
        proxy = proxy.detach().clone().rename_(*dim_names)

    try:
        proxy = proxy.broadcast_to(shape)
    except RuntimeError:
        raise ValueError(
            f"Shape mismatch: object cannot be broadcast to the given shape. Original shape: {x.shape}, target shape: {shape}"
        )

    output_split = proxy.names.index("split") if has_split else None

    if x.is_distributed():
        # input is distributed, return a broadcasted copy of input
        # exploit binary operations broadcasting
        result = factories.zeros(
            shape, dtype=x.dtype, split=output_split, device=x.device, comm=x.comm
        )
        result += x
        return result

    # return a view of the input data
    return DNDarray(
        x.larray.broadcast_to(shape),
        gshape=shape,
        dtype=x.dtype,
        split=output_split,
        device=x.device,
        comm=x.comm,
        balanced=True,
    )
def collect(arr: DNDarray, target_rank: Optional[int] = 0) -> DNDarray:
    """
    Collect a distributed DNDarray onto a single rank, chosen by `target_rank`. The input is left
    untouched: a copy is made and ``collect_`` is applied to that copy. It is a specific case of
    the ``redistribute_`` method.

    Parameters
    ----------
    arr : DNDarray
        The DNDarray to be collected.
    target_rank : int, optional
        The rank to which the DNDarray will be collected. Default: 0.

    Raises
    ------
    TypeError
        If the target rank is not an integer.
    ValueError
        If the target rank is out of bounds.

    Examples
    --------
    >>> st = ht.ones((50, 81, 67), split=2)
    >>> print(st.lshape)
    [0/2] (50, 81, 23)
    [1/2] (50, 81, 22)
    [2/2] (50, 81, 22)
    >>> collected_st = collect(st)
    >>> print(collected_st.lshape)
    [0/2] (50, 81, 67)
    [1/2] (50, 81, 0)
    [2/2] (50, 81, 0)
    >>> collected_st = collect(collected_st, 1)
    >>> print(collected_st.lshape)
    [0/2] (50, 81, 0)
    [1/2] (50, 81, 67)
    [2/2] (50, 81, 0)
    """
    gathered = arr.copy()
    gathered.collect_(target_rank=target_rank)
    return gathered
# Method form of :func:`collect`. It must forward to ``collect`` (not ``redistribute``), otherwise
# ``target_rank`` would be passed on as the second positional argument of ``redistribute``.
DNDarray.collect = lambda arr, target_rank=0: collect(arr, target_rank)
DNDarray.collect.__doc__ = collect.__doc__
def column_stack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
    """
    Stack 1-D or 2-D `DNDarray`s as columns into a 2-D `DNDarray`.
    If the input arrays are 1-D, they will be stacked as columns. If they are 2-D,
    they will be concatenated along the second axis.

    Parameters
    ----------
    arrays : Sequence[DNDarray, ...]
        Sequence of `DNDarray`s.

    Raises
    ------
    ValueError
        If arrays have more than 2 dimensions

    Notes
    -----
    All `DNDarray`s in the sequence must have the same number of rows.
    All `DNDarray`s must be split along the same axis! Note that distributed 1-D arrays (`split = 0`)
    by default will be transposed into distributed column arrays with `split == 1`.

    See Also
    --------
    :func:`concatenate`
    :func:`hstack`
    :func:`row_stack`
    :func:`stack`
    :func:`vstack`

    Examples
    --------
    >>> # 1-D tensors
    >>> a = ht.array([1, 2, 3])
    >>> b = ht.array([2, 3, 4])
    >>> ht.column_stack((a, b)).larray
    tensor([[1, 2],
            [2, 3],
            [3, 4]])
    >>> # 1-D and 2-D tensors
    >>> a = ht.array([1, 2, 3])
    >>> b = ht.array([[2, 5], [3, 6], [4, 7]])
    >>> c = ht.array([[7, 10], [8, 11], [9, 12]])
    >>> ht.column_stack((a, b, c)).larray
    tensor([[ 1,  2,  5,  7, 10],
            [ 2,  3,  6,  8, 11],
            [ 3,  4,  7,  9, 12]])
    >>> # distributed DNDarrays, 3 processes
    >>> a = ht.arange(10, split=0).reshape((5, 2))
    >>> b = ht.arange(5, 20, split=0).reshape((5, 3))
    >>> c = ht.arange(20, 40, split=0).reshape((5, 4))
    >>> ht.column_stack((a, b, c)).larray
    [0/2] tensor([[ 0,  1,  5,  6,  7, 20, 21, 22, 23],
    [0/2]         [ 2,  3,  8,  9, 10, 24, 25, 26, 27]], dtype=torch.int32)
    [1/2] tensor([[ 4,  5, 11, 12, 13, 28, 29, 30, 31],
    [1/2]         [ 6,  7, 14, 15, 16, 32, 33, 34, 35]], dtype=torch.int32)
    [2/2] tensor([[ 8,  9, 17, 18, 19, 36, 37, 38, 39]], dtype=torch.int32)
    >>> # distributed 1-D and 2-D DNDarrays, 3 processes
    >>> a = ht.arange(5, split=0)
    >>> b = ht.arange(5, 20, split=1).reshape((5, 3))
    >>> ht.column_stack((a, b)).larray
    [0/2] tensor([[ 0,  5],
    [0/2]         [ 1,  8],
    [0/2]         [ 2, 11],
    [0/2]         [ 3, 14],
    [0/2]         [ 4, 17]], dtype=torch.int32)
    [1/2] tensor([[ 6],
    [1/2]         [ 9],
    [1/2]         [12],
    [1/2]         [15],
    [1/2]         [18]], dtype=torch.int32)
    [2/2] tensor([[ 7],
    [2/2]         [10],
    [2/2]         [13],
    [2/2]         [16],
    [2/2]         [19]], dtype=torch.int32)
    """
    dims = [array.ndim for array in arrays]

    # sanitation, arrays can be 1-d or 2-d, see sanitation module #468
    if any(ndim > 2 for ndim in dims):
        raise ValueError("Arrays must be 1-D or 2-D")

    # only vectors: each one becomes a column of the result
    if all(ndim == 1 for ndim in dims):
        return stack(arrays, axis=1)

    if 1 in dims:
        # mixed input: turn every 1-D array into an explicit column before joining
        arrays = list(arrays)
        for position, ndim in enumerate(dims):
            if ndim == 1:
                vector = arrays[position]
                arrays[position] = vector.reshape((1, vector.size)).T

    return concatenate(arrays, axis=1)
def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray:
    """
    Join 2 or more `DNDarrays` along an existing axis.

    Parameters
    ----------
    arrays: Sequence[DNDarray, ...]
        The arrays must have the same shape, except in the dimension corresponding to axis.
    axis: int, optional
        The axis along which the arrays will be joined (default is 0).

    Raises
    ------
    RuntimeError
        If the concatenated :class:`~heat.core.dndarray.DNDarray` meta information, e.g. `split` or `comm`, does not match.
    TypeError
        If the passed parameters are not of correct type.
    ValueError
        If the number of passed arrays is less than two or their shapes do not match.

    Notes
    -----
    More than two arrays are joined pairwise from left to right by recursive calls.

    Examples
    --------
    >>> x = ht.zeros((3, 5), split=None)
    [0/1] tensor([[0., 0., 0., 0., 0.],
    [0/1]         [0., 0., 0., 0., 0.],
    [0/1]         [0., 0., 0., 0., 0.]])
    [1/1] tensor([[0., 0., 0., 0., 0.],
    [1/1]         [0., 0., 0., 0., 0.],
    [1/1]         [0., 0., 0., 0., 0.]])
    >>> y = ht.ones((3, 6), split=0)
    [0/1] tensor([[1., 1., 1., 1., 1., 1.],
    [0/1]         [1., 1., 1., 1., 1., 1.]])
    [1/1] tensor([[1., 1., 1., 1., 1., 1.]])
    >>> ht.concatenate((x, y), axis=1)
    [0/1] tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
    [0/1]         [0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]])
    [1/1] tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]])
    >>> x = ht.zeros((4, 5), split=1)
    [0/1] tensor([[0., 0., 0.],
    [0/1]         [0., 0., 0.],
    [0/1]         [0., 0., 0.],
    [0/1]         [0., 0., 0.]])
    [1/1] tensor([[0., 0.],
    [1/1]         [0., 0.],
    [1/1]         [0., 0.],
    [1/1]         [0., 0.]])
    >>> y = ht.ones((3, 5), split=1)
    [0/1] tensor([[1., 1., 1.],
    [0/1]         [1., 1., 1.],
    [0/1]         [1., 1., 1.]])
    [1/1] tensor([[1., 1.],
    [1/1]         [1., 1.],
    [1/1]         [1., 1.]])
    >>> ht.concatenate((x, y), axis=0)
    [0/1] tensor([[0., 0., 0.],
    [0/1]         [0., 0., 0.],
    [0/1]         [0., 0., 0.],
    [0/1]         [0., 0., 0.],
    [0/1]         [1., 1., 1.],
    [0/1]         [1., 1., 1.],
    [0/1]         [1., 1., 1.]])
    [1/1] tensor([[0., 0.],
    [1/1]         [0., 0.],
    [1/1]         [0., 0.],
    [1/1]         [0., 0.],
    [1/1]         [1., 1.],
    [1/1]         [1., 1.],
    [1/1]         [1., 1.]])
    """
    # input sanitation
    arrays = sanitation.sanitize_sequence(arrays)
    for arr in arrays:
        sanitation.sanitize_in(arr)

    # a single array cannot be concatenated
    if len(arrays) < 2:
        raise ValueError("concatenate requires 2 arrays")
    # concatenate multiple arrays
    elif len(arrays) > 2:
        res = concatenate((arrays[0], arrays[1]), axis=axis)
        for a in range(2, len(arrays)):
            res = concatenate((res, arrays[a]), axis=axis)
        return res

    # unpack the arrays
    arr0, arr1 = arrays

    if not isinstance(axis, int):
        raise TypeError(f"axis must be an integer, currently: {type(axis)}")
    axis = stride_tricks.sanitize_axis(arr0.gshape, axis)

    if arr0.ndim != arr1.ndim:
        raise ValueError("DNDarrays must have the same number of dimensions")

    if any(arr0.gshape[i] != arr1.gshape[i] for i in range(len(arr0.gshape)) if i != axis):
        raise ValueError(
            f"Arrays cannot be concatenated, shapes must be the same in every axis except the selected axis: {arr0.gshape}, {arr1.gshape}"
        )

    # different communicators may not be concatenated
    if arr0.comm != arr1.comm:
        raise RuntimeError("Communicators of passed arrays mismatch.")

    # identify common data type
    is_mps = arr0.larray.is_mps or arr1.larray.is_mps
    out_dtype = types.promote_types(arr0.dtype, arr1.dtype)
    if is_mps and out_dtype == types.float64:
        warnings.warn("MPS does not support float64, using float32 instead")
        out_dtype = types.float32
    if arr0.dtype != out_dtype:
        arr0 = out_dtype(arr0, device=arr0.device)
    if arr1.dtype != out_dtype:
        arr1 = out_dtype(arr1, device=arr1.device)

    s0, s1 = arr0.split, arr1.split
    # no splits, local concat
    if s0 is None and s1 is None:
        return factories.array(
            torch.cat((arr0.larray, arr1.larray), dim=axis),
            device=arr0.device,
            comm=arr0.comm,
        )

    # non-matching splits when both arrays are split
    elif s0 != s1 and all([s is not None for s in [s0, s1]]):
        raise RuntimeError(f"DNDarrays given have differing split axes, arr0 {s0} arr1 {s1}")

    # exactly one array is split, and not along the concatenation axis:
    # cut the local chunk out of both arrays and join locally
    elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
        _, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split)
        _, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split)
        out = factories.array(
            torch.cat((arr0.larray[arr0_slice], arr1.larray[arr1_slice]), dim=axis),
            dtype=out_dtype,
            is_split=s1 if s1 is not None else s0,
            device=arr1.device,
            comm=arr0.comm,
        )

        return out

    elif s0 == s1 or any(s is None for s in [s0, s1]):
        if s0 != axis and all(s is not None for s in [s0, s1]):
            # the axis is different than the split axis, this case can be easily implemented
            # torch cat arrays together and return a new array that is_split
            out = factories.array(
                torch.cat((arr0.larray, arr1.larray), dim=axis),
                dtype=out_dtype,
                is_split=s0,
                device=arr0.device,
                comm=arr0.comm,
            )
            return out

        else:
            t_arr0 = arr0.larray
            t_arr1 = arr1.larray
            # maps are created for where the data is and the output shape is calculated
            lshape_map = torch.zeros((2, arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
            lshape_map[0, arr0.comm.rank, :] = torch.Tensor(arr0.lshape)
            lshape_map[1, arr0.comm.rank, :] = torch.Tensor(arr1.lshape)
            lshape_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

            arr0_shape, arr1_shape = list(arr0.shape), list(arr1.shape)
            arr0_shape[axis] += arr1_shape[axis]
            out_shape = tuple(arr0_shape)

            # the chunk map is used to determine how much data should be on each process
            chunk_map = torch.zeros((arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
            _, _, chk = arr0.comm.chunk(out_shape, s0 if s0 is not None else s1)
            for i in range(len(out_shape)):
                chunk_map[arr0.comm.rank, i] = chk[i].stop - chk[i].start
            chunk_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, chunk_map, MPI.SUM)

            # both maps must be complete on every process before they are read
            lshape_map_comm.Wait()
            chunk_map_comm.Wait()

            if s0 is not None:
                send_slice = [slice(None)] * arr0.ndim
                keep_slice = [slice(None)] * arr0.ndim
                # data is first front-loaded onto the first size/2 processes
                for spr in range(1, arr0.comm.size):
                    if arr0.comm.rank == spr:
                        for pr in range(spr):
                            send_amt = abs((chunk_map[pr, axis] - lshape_map[0, pr, axis]).item())
                            send_amt = (
                                send_amt if send_amt < t_arr0.shape[axis] else t_arr0.shape[axis]
                            )
                            if send_amt:
                                send_slice[arr0.split] = slice(0, send_amt)
                                keep_slice[arr0.split] = slice(send_amt, t_arr0.shape[axis])
                                send = arr0.comm.Isend(
                                    t_arr0[tuple(send_slice)].clone(),
                                    dest=pr,
                                    tag=pr + arr0.comm.size + spr,
                                )
                                t_arr0 = t_arr0[tuple(keep_slice)].clone()
                                send.Wait()
                    for pr in range(spr):
                        snt = abs((chunk_map[pr, s0] - lshape_map[0, pr, s0]).item())
                        snt = (
                            snt
                            if snt < lshape_map[0, spr, axis]
                            else lshape_map[0, spr, axis].item()
                        )
                        if arr0.comm.rank == pr and snt:
                            shp = list(arr0.gshape)
                            shp[arr0.split] = snt
                            data = torch.zeros(
                                shp, dtype=out_dtype.torch_type(), device=arr0.device.torch_device
                            )

                            arr0.comm.Recv(data, source=spr, tag=pr + arr0.comm.size + spr)
                            t_arr0 = torch.cat((t_arr0, data), dim=arr0.split)
                        # every process tracks the transfer so its local shape map stays in sync
                        lshape_map[0, pr, arr0.split] += snt
                        lshape_map[0, spr, arr0.split] -= snt

            if s1 is not None:
                send_slice = [slice(None)] * arr0.ndim
                keep_slice = [slice(None)] * arr0.ndim
                # push the data backwards (arr1), making the data the proper size for arr1 on the last nodes
                # the data is "compressed" on np/2 processes. data is sent from
                for spr in range(arr0.comm.size - 1, -1, -1):
                    if arr0.comm.rank == spr:
                        for pr in range(arr0.comm.size - 1, spr, -1):
                            # calculate the amount of data to send from the chunk map
                            send_amt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
                            send_amt = (
                                send_amt if send_amt < t_arr1.shape[axis] else t_arr1.shape[axis]
                            )
                            if send_amt:
                                send_slice[axis] = slice(
                                    t_arr1.shape[axis] - send_amt, t_arr1.shape[axis]
                                )
                                keep_slice[axis] = slice(0, t_arr1.shape[axis] - send_amt)
                                send = arr1.comm.Isend(
                                    t_arr1[tuple(send_slice)].clone(),
                                    dest=pr,
                                    tag=pr + arr1.comm.size + spr,
                                )
                                t_arr1 = t_arr1[tuple(keep_slice)].clone()
                                send.Wait()
                    for pr in range(arr1.comm.size - 1, spr, -1):
                        snt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
                        snt = (
                            snt
                            if snt < lshape_map[1, spr, axis]
                            else lshape_map[1, spr, axis].item()
                        )

                        if arr1.comm.rank == pr and snt:
                            shp = list(arr1.gshape)
                            shp[axis] = snt
                            data = torch.zeros(
                                shp, dtype=out_dtype.torch_type(), device=arr1.device.torch_device
                            )
                            arr1.comm.Recv(data, source=spr, tag=pr + arr1.comm.size + spr)
                            t_arr1 = torch.cat((data, t_arr1), dim=axis)
                        # every process tracks the transfer so its local shape map stays in sync
                        lshape_map[1, pr, axis] += snt
                        lshape_map[1, spr, axis] -= snt

            if s0 is None:
                arb_slice = [None] * len(arr1.shape)
                for c in range(len(chunk_map)):
                    arb_slice[axis] = c
                    # the chunk map is adjusted by subtracting what data is already in the correct place (the data from
                    # arr1 is already correctly placed) i.e. the chunk map shows how much data is still needed on each
                    # process, the local
                    chunk_map[tuple(arb_slice)] -= lshape_map[tuple([1] + arb_slice)]

                # after adjusting arr1 need to now select the target data in arr0 on each node with a local slice
                if arr0.comm.rank == 0:
                    lcl_slice: list[slice[Any, Any, Any]] | Any = [slice(None)] * arr0.ndim
                    lcl_slice[axis] = slice(chunk_map[0, axis].item())
                    t_arr0 = t_arr0[tuple(lcl_slice)].clone().squeeze()
                ttl = chunk_map[0, axis].item()
                for en in range(1, arr0.comm.size):
                    sz = chunk_map[en, axis]
                    if arr0.comm.rank == en:
                        lcl_slice = [slice(None)] * arr0.ndim
                        lcl_slice[axis] = slice(ttl, sz.item() + ttl, 1)
                        # index with a tuple: indexing a tensor with a list of slices is deprecated in torch
                        t_arr0 = t_arr0[tuple(lcl_slice)].clone().squeeze()
                    ttl += sz.item()

                if len(t_arr0.shape) < len(t_arr1.shape):
                    t_arr0.unsqueeze_(axis)

            if s1 is None:
                arb_slice = [None] * len(arr0.shape)
                for c in range(len(chunk_map)):
                    arb_slice[axis] = c
                    chunk_map[tuple(arb_slice)] -= lshape_map[tuple([0] + arb_slice)]

                # get the desired data in arr1 on each node with a local slice
                if arr1.comm.rank == arr1.comm.size - 1:
                    lcl_slice = [slice(None)] * arr1.ndim
                    lcl_slice[axis] = slice(
                        t_arr1.shape[axis] - chunk_map[-1, axis].item(), t_arr1.shape[axis], 1
                    )
                    t_arr1 = t_arr1[tuple(lcl_slice)].clone().squeeze()
                ttl = chunk_map[-1, axis].item()
                for en in range(arr1.comm.size - 2, -1, -1):
                    sz = chunk_map[en, axis]
                    if arr1.comm.rank == en:
                        lcl_slice = [slice(None)] * arr1.ndim
                        lcl_slice[axis] = slice(
                            t_arr1.shape[axis] - (sz.item() + ttl), t_arr1.shape[axis] - ttl, 1
                        )
                        # index with a tuple: indexing a tensor with a list of slices is deprecated in torch
                        t_arr1 = t_arr1[tuple(lcl_slice)].clone().squeeze()
                    ttl += sz.item()
                if len(t_arr1.shape) < len(t_arr0.shape):
                    t_arr1.unsqueeze_(axis)

            res = torch.cat((t_arr0, t_arr1), dim=axis)
            out = factories.array(
                res,
                is_split=s0 if s0 is not None else s1,
                dtype=out_dtype,
                device=arr0.device,
                comm=arr0.comm,
            )

            return out
def diag(a: DNDarray, offset: int = 0) -> DNDarray:
    """
    Extract a diagonal or construct a diagonal array.
    See the documentation for :func:`diagonal` for more information about extracting the diagonal.

    Parameters
    ----------
    a: DNDarray
        The array holding data for creating a diagonal array or extracting a diagonal.
        If `a` is a 1-dimensional array, a diagonal 2d-array will be returned.
        If `a` is a n-dimensional array with n > 1 the diagonal entries will be returned in an n-1 dimensional array.
    offset: int, optional
        The offset from the main diagonal.
        Offset greater than zero means above the main diagonal, smaller than zero is below the main diagonal.

    Raises
    ------
    ValueError
        If `a` has no dimensions, or if `offset` is not an integer.

    See Also
    --------
    :func:`diagonal`

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.array([1, 2])
    >>> ht.diag(a)
    DNDarray([[1, 0],
              [0, 2]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.diag(a, offset=1)
    DNDarray([[0, 1, 0],
              [0, 0, 2],
              [0, 0, 0]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.equal(ht.diag(ht.diag(a)), a)
    True
    >>> a = ht.array([[1, 2], [3, 4]])
    >>> ht.diag(a)
    DNDarray([1, 4], dtype=ht.int64, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(a)

    if len(a.shape) > 1:
        return diagonal(a, offset=offset)
    elif len(a.shape) < 1:
        raise ValueError("input array must be of dimension 1 or greater")
    if not isinstance(offset, int):
        # format the message; passing the type as a second argument yields a tuple-shaped message
        raise ValueError(f"offset must be an integer, got {type(offset)}")

    # 1-dimensional array, must be extended to a square diagonal matrix
    gshape = (a.shape[0] + abs(offset),) * 2
    off, lshape, _ = a.comm.chunk(gshape, a.split)

    # This ensures that the data is on the correct nodes
    if offset > 0:
        padding = factories.empty(
            (offset,), dtype=a.dtype, split=None, device=a.device, comm=a.comm
        )
        a = concatenate((a, padding))
        # create the index on the array's device, like the other two branches do
        indices_x = torch.arange(
            0,
            min(lshape[0], max(gshape[0] - off - offset, 0)),
            device=a.device.torch_device,
        )
    elif offset < 0:
        padding = factories.empty(
            (abs(offset),), dtype=a.dtype, split=None, device=a.device, comm=a.comm
        )
        a = concatenate((padding, a))
        indices_x = torch.arange(
            max(0, min(abs(offset) - off, lshape[0])), lshape[0], device=a.device.torch_device
        )
    else:
        # Offset = 0 values on main diagonal
        indices_x = torch.arange(0, lshape[0], device=a.device.torch_device)

    # column of each local diagonal entry: local row + chunk offset + diagonal offset
    indices_y = indices_x + off + offset
    a.balance_()

    local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device)
    local[indices_x, indices_y] = a.larray[indices_x]

    return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray:
    """
    Extract a diagonal of an n-dimensional array with n > 1.
    The returned array will be of dimension n-1.

    Parameters
    ----------
    a: DNDarray
        The array of which the diagonal should be extracted.
    offset: int, optional
        The offset from the main diagonal.
        Offset greater than zero means above the main diagonal, smaller than zero is below the main diagonal.
        Default is 0 which means the main diagonal will be selected.
    dim1: int, optional
        First dimension with respect to which to take the diagonal.
    dim2: int, optional
        Second dimension with respect to which to take the diagonal.

    Raises
    ------
    ValueError
        If `a` is not a DNDarray, if `offset` is not an integer, or if `dim1` and `dim2` refer
        to the same dimension.

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.array([[1, 2], [3, 4]])
    >>> ht.diagonal(a)
    DNDarray([1, 4], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.diagonal(a, offset=1)
    DNDarray([2], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.diagonal(a, offset=-1)
    DNDarray([3], dtype=ht.int64, device=cpu:0, split=None)
    >>> a = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
    >>> ht.diagonal(a)
    DNDarray([[0, 6],
              [1, 7]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.diagonal(a, dim2=2)
    DNDarray([[0, 5],
              [2, 7]], dtype=ht.int64, device=cpu:0, split=None)
    """
    # validate the inputs before touching ``a.shape``, so a wrong type gives the intended error
    if not isinstance(a, DNDarray):
        raise ValueError(f"a must be a DNDarray, got {type(a)}")
    if not isinstance(offset, int):
        raise ValueError(f"offset must be an integer, got {type(offset)}")

    dim1, dim2 = stride_tricks.sanitize_axis(a.shape, (dim1, dim2))

    if dim1 == dim2:
        raise ValueError("Dim1 and dim2 need to be different")

    shape = a.gshape
    ax1 = shape[dim1]
    ax2 = shape[dim2]
    # determine the number of diagonal elements that will be retrieved
    length = min(ax1, ax2 - offset) if offset >= 0 else min(ax2, ax1 + offset)
    # Remove dim1 and dim2 from shape and append resulting length
    shape = tuple([x for ind, x in enumerate(shape) if ind not in (dim1, dim2)]) + (length,)
    x, y = min(dim1, dim2), max(dim1, dim2)

    # the split axis shifts left by one for every removed dimension in front of it;
    # if the split axis itself is removed, the result is split along the new last axis
    if a.split is None:
        split = None
    elif a.split < x < y:
        split = a.split
    elif x < a.split < y:
        split = a.split - 1
    elif x < y < a.split:
        split = a.split - 2
    else:
        split = len(shape) - 1

    if a.split is None or a.split not in (dim1, dim2):
        result = torch.diagonal(a.larray, offset=offset, dim1=dim1, dim2=dim2).contiguous()
    else:
        # the local chunk starts ``off`` entries into the split dimension, so the diagonal
        # offset has to be corrected by that amount (sign depends on which dimension is split)
        vz = 1 if a.split == dim1 else -1
        off, _, _ = a.comm.chunk(a.shape, a.split)
        result = torch.diagonal(
            a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2
        ).contiguous()
    return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm)
def dsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
    """
    Split array into multiple sub-DNDarrays along the 3rd axis (depth).
    Returns a list of sub-DNDarrays as copies of parts of `x`.

    Parameters
    ----------
    x : DNDarray
        DNDArray to be divided into sub-DNDarrays.
    indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple)
        If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays
        along the 3rd axis. If such a split is not possible, an error is raised.
        If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along
        the 3rd axis the array is split. If an index exceeds the dimension of the array along the 3rd axis,
        an empty sub-DNDarray is returned correspondingly.

    Raises
    ------
    ValueError
        If `indices_or_sections` is given as integer, but a split does not result in equal division.

    Notes
    -----
    Please refer to the split documentation. dsplit is equivalent to split with axis=2,
    the array is always split along the third axis provided the array dimension is greater than or equal to 3.

    See Also
    --------
    :func:`split`
    :func:`hsplit`
    :func:`vsplit`

    Examples
    --------
    >>> x = ht.arange(24).reshape((2, 3, 4))
    >>> ht.dsplit(x, 2)
    [DNDarray([[[ 0,  1],
                [ 4,  5],
                [ 8,  9]],
               [[12, 13],
                [16, 17],
                [20, 21]]]),
     DNDarray([[[ 2,  3],
                [ 6,  7],
                [10, 11]],
               [[14, 15],
                [18, 19],
                [22, 23]]])]
    >>> ht.dsplit(x, [1, 4])
    [DNDarray([[[ 0],
                [ 4],
                [ 8]],
               [[12],
                [16],
                [20]]]),
     DNDarray([[[ 1,  2,  3],
                [ 5,  6,  7],
                [ 9, 10, 11]],
               [[13, 14, 15],
                [17, 18, 19],
                [21, 22, 23]]]),
     DNDarray([])]
    """
    # depth is the third axis
    return split(x, indices_or_sections, 2)
def expand_dims(a: DNDarray, axis: Union[int, Sequence[int]]) -> DNDarray:
    """
    Expand the shape of an array.
    Insert a new axis that will appear at the axis position in the expanded array shape.

    Parameters
    ----------
    a : DNDarray
        Input array to be expanded.
    axis : int or Sequence[int]
        Position(s) in the expanded axes where the new axis (or axes) is placed. When several
        positions are given, they refer to the shape of the result and may be passed in any order.

    Raises
    ------
    ValueError
        If `axis` is not consistent with the available dimensions.

    Examples
    --------
    >>> x = ht.array([1, 2])
    >>> x.shape
    (2,)
    >>> y = ht.expand_dims(x, axis=0)
    >>> y
    array([[1, 2]])
    >>> y.shape
    (1, 2)
    >>> y = ht.expand_dims(x, axis=1)
    >>> y
    array([[1],
           [2]])
    >>> y.shape
    (2, 1)
    """
    # sanitize input
    sanitation.sanitize_in(a)

    # track split axis
    split_bookkeeping = [None] * a.ndim
    if a.split is not None:
        split_bookkeeping[a.split] = "split"

    output_shape = list(a.shape)
    local_expansion = a.larray
    if isinstance(axis, (tuple, list)):
        # sanitize axis, introduce arbitrary dummy dimensions to model expansion
        axis = stride_tricks.sanitize_axis(a.shape + (1,) * len(axis), axis)
        # Insert in ascending order: each position refers to the final shape, so inserting a
        # larger position first would either be out of range or be shifted by a later insertion.
        for ax in sorted(axis):
            split_bookkeeping.insert(ax, None)
            output_shape.insert(ax, 1)
            local_expansion = local_expansion.unsqueeze(dim=ax)
    else:
        # sanitize axis, introduce arbitrary dummy dimensions to model expansion
        axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
        split_bookkeeping.insert(axis, None)
        output_shape.insert(axis, 1)
        local_expansion = local_expansion.unsqueeze(dim=axis)

    # the split axis may have moved to the right because of axes inserted in front of it
    output_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
    output_shape = tuple(output_shape)

    return DNDarray(
        local_expansion,
        output_shape,
        a.dtype,
        output_split,
        a.device,
        a.comm,
        a.balanced,
    )
# expose ``expand_dims`` as a ``DNDarray`` method and reuse its docstring
DNDarray.expand_dims = lambda self, axis: expand_dims(self, axis)
DNDarray.expand_dims.__doc__ = expand_dims.__doc__
def flatten(a: DNDarray) -> DNDarray:
    """
    Flattens an array into one dimension.

    Parameters
    ----------
    a : DNDarray
        Array to collapse

    Warning
    ----------
    If `a.split>0`, the array must be redistributed along the first axis (see :func:`resplit`).

    See Also
    --------
    :func:`ravel`

    Examples
    --------
    >>> a = ht.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    >>> ht.flatten(a)
    DNDarray([1, 2, 3, 4, 5, 6, 7, 8], dtype=ht.int64, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(a)

    if a.split is not None:
        # the data has to be split along the first axis before the local chunks are flattened
        if a.split > 0:
            a = resplit(a, 0)

        flat = factories.array(
            torch.flatten(a.larray),
            dtype=a.dtype,
            is_split=a.split,
            device=a.device,
            comm=a.comm,
        )
        flat.balance_()
        return flat

    # not split: flatten the local tensor directly
    return factories.array(
        torch.flatten(a.larray),
        dtype=a.dtype,
        is_split=None,
        device=a.device,
        comm=a.comm,
    )
# expose ``flatten`` as a ``DNDarray`` method and reuse its docstring
DNDarray.flatten = lambda self: flatten(self)
DNDarray.flatten.__doc__ = flatten.__doc__
def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray:
    """
    Reverse the order of elements in an array along the given axis.
    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    a: DNDarray
        Input array to be flipped
    axis: int or Tuple[int,...]
        The axis or sequence of axes to be flipped. If ``None`` (default), all axes are flipped.

    See Also
    --------
    :func:`fliplr`
    :func:`flipud`

    Examples
    --------
    >>> a = ht.array([[0, 1], [2, 3]])
    >>> ht.flip(a, [0])
    DNDarray([[2, 3],
              [0, 1]], dtype=ht.int64, device=cpu:0, split=None)
    >>> b = ht.array([[0, 1, 2], [3, 4, 5]], split=1)
    >>> ht.flip(b, [0, 1])
    (1/2) tensor([5,4,3])
    (2/2) tensor([2,1,0])
    """
    # flip all dimensions
    if axis is None:
        axis = tuple(range(a.ndim))

    # torch.flip only accepts tuples
    if isinstance(axis, int):
        axis = (axis,)
    elif isinstance(axis, list):
        axis = tuple(axis)

    axis = stride_tricks.sanitize_axis(a.shape, axis)

    local_flipped = torch.flip(a.larray, axis)

    needs_exchange = a.is_distributed() and a.split in axis
    if not needs_exchange:
        return factories.array(
            local_flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
        )

    # Need to redistribute tensors on split axis: each process swaps its chunk with the
    # process mirrored at the other end of the communicator
    partner = a.comm.size - 1 - a.comm.rank
    incoming_lshape = a.comm.sendrecv(a.lshape, dest=partner, source=partner)

    # Exchange local tensors
    pending_send = a.comm.Isend(local_flipped, dest=partner)
    incoming = torch.empty(incoming_lshape, dtype=a.larray.dtype, device=a.device.torch_device)
    a.comm.Recv(incoming, source=partner)

    result = factories.array(
        incoming, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
    )
    result.balance_()  # after swapping, first processes may be empty
    pending_send.Wait()

    return result
def fliplr(a: DNDarray) -> DNDarray:
    """
    Flip array in the left/right direction. If `a.ndim>2`, flip along dimension 1.

    Parameters
    ----------
    a: DNDarray
        Input array to be flipped, must be at least 2-D

    See Also
    --------
    :func:`flip`
    :func:`flipud`

    Examples
    --------
    >>> a = ht.array([[0, 1], [2, 3]])
    >>> ht.fliplr(a)
    DNDarray([[1, 0],
              [3, 2]], dtype=ht.int64, device=cpu:0, split=None)
    >>> b = ht.array([[0, 1, 2], [3, 4, 5]], split=0)
    >>> ht.fliplr(b)
    (1/2) tensor([[2, 1, 0]])
    (2/2) tensor([[5, 4, 3]])
    """
    # left/right corresponds to the second dimension
    return flip(a, axis=1)
def flipud(a: DNDarray) -> DNDarray:
    """
    Flip array in the up/down direction.

    Parameters
    ----------
    a: DNDarray
        Input array to be flipped

    See Also
    --------
    :func:`flip`
    :func:`fliplr`

    Examples
    --------
    >>> a = ht.array([[0, 1], [2, 3]])
    >>> ht.flipud(a)
    DNDarray([[2, 3],
              [0, 1]], dtype=ht.int64, device=cpu:0, split=None)
    >>> b = ht.array([[0, 1, 2], [3, 4, 5]], split=0)
    >>> ht.flipud(b)
    (1/2) tensor([3,4,5])
    (2/2) tensor([0,1,2])
    """
    # up/down corresponds to the first dimension
    return flip(a, axis=0)
def hsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
    """
    Split array into multiple sub-DNDarrays along the 2nd axis (horizontally/column-wise).
    Returns a list of sub-DNDarrays as copies of parts of `x`.

    Parameters
    ----------
    x : DNDarray
        DNDArray to be divided into sub-DNDarrays.
    indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple)
        If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays
        along the 2nd axis. If such a split is not possible, an error is raised.
        If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along
        the 2nd axis the array is split. If an index exceeds the dimension of the array along the 2nd axis,
        an empty sub-DNDarray is returned correspondingly.

    Raises
    ------
    ValueError
        If `indices_or_sections` is given as integer, but a split does not result in equal division.

    Notes
    -----
    Please refer to the split documentation. hsplit is nearly equivalent to split with axis=1,
    the array is always split along the second axis though, in contrary to split, regardless of the array dimension.
    A 1-D input is temporarily viewed as a single row, split, and the pieces are flattened again.

    See Also
    --------
    :func:`split`
    :func:`dsplit`
    :func:`vsplit`

    Examples
    --------
    >>> x = ht.arange(24).reshape((2, 4, 3))
    >>> ht.hsplit(x, 2)
    [DNDarray([[[ 0,  1,  2],
                [ 3,  4,  5]],
               [[12, 13, 14],
                [15, 16, 17]]]),
     DNDarray([[[ 6,  7,  8],
                [ 9, 10, 11]],
               [[18, 19, 20],
                [21, 22, 23]]])]
    >>> ht.hsplit(x, [1, 3])
    [DNDarray([[[ 0,  1,  2]],
               [[12, 13, 14]]]),
     DNDarray([[[ 3,  4,  5],
                [ 6,  7,  8]],
               [[15, 16, 17],
                [18, 19, 20]]]),
     DNDarray([[[ 9, 10, 11]],
               [[21, 22, 23]]])]
    """
    sanitation.sanitize_in(x)

    if len(x.shape) < 2:
        # Reshape with the global length: the process-local length (``lshape``) differs from it
        # as soon as ``x`` is distributed and would not describe the whole array.
        x = reshape(x, (1, x.shape[0]))
        result = split(x, indices_or_sections, 1)
        result = [flatten(sub_array) for sub_array in result]
    else:
        result = split(x, indices_or_sections, 1)

    return result
def hstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
    """
    Stack arrays in sequence horizontally (column-wise).
    This is equivalent to concatenation along the second axis, except for 1-D
    arrays where it concatenates along the first axis.

    Parameters
    ----------
    arrays : Sequence[DNDarray, ...]
        The arrays must have the same shape along all but the second axis,
        except 1-D arrays which can be any length.

    Notes
    -----
    If *all* inputs are 1-D (any number of them), they are concatenated along axis 0.
    If 1-D and higher-dimensional inputs are mixed, each 1-D input is first expanded to a
    column (``expand_dims(1)``) and everything is concatenated along axis 1.

    See Also
    --------
    :func:`concatenate`
    :func:`stack`
    :func:`vstack`
    :func:`column_stack`
    :func:`row_stack`

    Examples
    --------
    >>> a = ht.array((1, 2, 3))
    >>> b = ht.array((2, 3, 4))
    >>> ht.hstack((a, b)).larray
    [0/1] tensor([1, 2, 3, 2, 3, 4])
    [1/1] tensor([1, 2, 3, 2, 3, 4])
    >>> a = ht.array((1, 2, 3), split=0)
    >>> b = ht.array((2, 3, 4), split=0)
    >>> ht.hstack((a, b)).larray
    [0/1] tensor([1, 2, 3])
    [1/1] tensor([2, 3, 4])
    >>> a = ht.array([[1], [2], [3]], split=0)
    >>> b = ht.array([[2], [3], [4]], split=0)
    >>> ht.hstack((a, b)).larray
    [0/1] tensor([[1, 2],
    [0/1]         [2, 3]])
    [1/1] tensor([[3, 4]])
    """
    arrays = list(arrays)

    # Pure 1-D input is joined end-to-end along the first axis. This used to be applied only
    # when exactly two arrays were given, so three or more vectors were stacked as columns,
    # contradicting the documented (NumPy-compatible) behaviour.
    if arrays and all(len(arr.gshape) == 1 for arr in arrays):
        return concatenate(arrays, axis=0)

    # Mixed dimensionality: promote vectors to single columns before joining along axis 1.
    arrays = [arr.expand_dims(1) if len(arr.gshape) == 1 else arr for arr in arrays]
    return concatenate(arrays, axis=1)
def moveaxis(
    x: DNDarray, source: Union[int, Sequence[int]], destination: Union[int, Sequence[int]]
) -> DNDarray:
    """
    Relocate one or more axes of `x`; the remaining axes keep their relative order.

    Parameters
    ----------
    x : DNDarray
        The input array.
    source : int or Sequence[int, ...]
        Original positions of the axes to move. These must be unique.
    destination : int or Sequence[int, ...]
        Destination positions for each of the original axes. These must also be unique.

    See Also
    --------
    ~heat.core.linalg.basics.transpose
        Permute the dimensions of an array.

    Raises
    ------
    TypeError
        If `source` or `destination` are not ints, lists or tuples.
    ValueError
        If `source` and `destination` do not have the same number of elements.

    Examples
    --------
    >>> x = ht.zeros((3, 4, 5))
    >>> ht.moveaxis(x, 0, -1).shape
    (4, 5, 3)
    >>> ht.moveaxis(x, -1, 0).shape
    (5, 3, 4)
    """

    def _normalise(axes, failure_message):
        # ints become 1-tuples and lists become tuples before the axes are sanitized
        if isinstance(axes, int):
            axes = (axes,)
        if isinstance(axes, list):
            axes = tuple(axes)
        try:
            return stride_tricks.sanitize_axis(x.shape, axes)
        except TypeError:
            raise TypeError(failure_message)

    source = _normalise(source, "'source' must be ints, lists or tuples.")
    destination = _normalise(destination, "'destination' must be ints, lists or tuples.")

    if len(source) != len(destination):
        raise ValueError("'source' and 'destination' must have the same number of elements.")

    # start from the axes that stay put ...
    permutation = []
    for ax in range(x.ndim):
        if ax not in source:
            permutation.append(ax)

    # ... and slot the moved axes in, lowest destination first, so that earlier
    # insertions do not shift the position of later ones
    placements = sorted(zip(destination, source))
    for target, origin in placements:
        permutation.insert(target, origin)

    return linalg.transpose(x, permutation)
def pad(
    array: DNDarray,
    pad_width: Union[int, Sequence[Sequence[int, int], ...]],
    mode: str = "constant",
    constant_values: int = 0,
) -> DNDarray:
    """
    Pads tensor with a specific value (default=0).
    (Not all dimensions supported)

    Parameters
    ----------
    array : DNDarray
        Array to be padded
    pad_width: Union[int, Sequence[Sequence[int, int], ...]]
        Number of values padded to the edges of each axis. ((before_1, after_1),...(before_N, after_N)) unique pad widths for each axis.
        Determines how many elements are padded along which dimension.

        Shortcuts:

            - ((before, after),)  or (before, after): before and after pad width for each axis.
            - (pad_width,) or int: before = after = pad width for all axes.

        Therefore:

        - pad last dimension: (padding_left, padding_right)
        - pad last 2 dimensions: ((padding_top, padding_bottom),(padding_left, padding_right))
        - pad last 3 dimensions: ((padding_front, padding_back),(padding_top, padding_bottom),(paddling_left, padding_right) )
        - ... (same pattern)
    mode : str, optional
        - 'constant' (default): Pads the input tensor boundaries with a constant value. This is available for arbitrary dimensions
    constant_values: Union[int, float, Sequence[Sequence[int,int], ...], Sequence[Sequence[float,float], ...]]
        Number or tuple of 2-element-sequences (containing numbers), optional (default=0)
        The fill values for each axis (1 tuple per axis).
        ((before_1, after_1), ... (before_N, after_N)) unique pad values for each axis.

        Shortcuts:

            - ((before, after),) or (before, after): before and after padding values for each axis.
            - (value,) or int: before = after = padding value for all axes.

    Raises
    ------
    TypeError
        If `array` is not a DNDarray, `mode` is not a string, or `pad_width` / `constant_values`
        contain elements of an unsupported type.
    ValueError
        If a sequence inside `pad_width` does not hold 2 elements, if more dimensions are padded
        than `array` has, or if nested `constant_values` hold an odd number of elements.

    Notes
    -----
    This function follows the principle of datatype integrity.
    Therefore, an array can only be padded with values of the same datatype.
    All values that violate this rule are implicitly cast to the datatype of the `DNDarray`.

    If `pad_width` equals 0 the input `array` itself is returned (before any type checks).
    The result is balanced (``balance_()``) before it is returned.

    Examples
    --------
    >>> a = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
    >>> b = ht.array(a, split=0)

    Pad last dimension

    >>> c = ht.pad(b, (2, 1), constant_values=1)
    tensor([[[ 1,  1,  0,  1,  2,  3,  1],
             [ 1,  1,  4,  5,  6,  7,  1],
             [ 1,  1,  8,  9, 10, 11,  1]],
            [[ 1,  1, 12, 13, 14, 15,  1],
             [ 1,  1, 16, 17, 18, 19,  1],
             [ 1,  1, 20, 21, 22, 23,  1]]])

    Pad last 2 dimensions

    >>> d = ht.pad(b, [(1, 0), (2, 1)])
    DNDarray([[[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  1,  2,  3,  0],
               [ 0,  0,  4,  5,  6,  7,  0],
               [ 0,  0,  8,  9, 10, 11,  0]],
              [[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0, 12, 13, 14, 15,  0],
               [ 0,  0, 16, 17, 18, 19,  0],
               [ 0,  0, 20, 21, 22, 23,  0]]], dtype=ht.int64, device=cpu:0, split=0)

    Pad last 3 dimensions

    >>> e = ht.pad(b, ((2, 1), [1, 0], (2, 1)))
    DNDarray([[[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0]],
              [[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0]],
              [[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  1,  2,  3,  0],
               [ 0,  0,  4,  5,  6,  7,  0],
               [ 0,  0,  8,  9, 10, 11,  0]],
              [[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0, 12, 13, 14, 15,  0],
               [ 0,  0, 16, 17, 18, 19,  0],
               [ 0,  0, 20, 21, 22, 23,  0]],
              [[ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0],
               [ 0,  0,  0,  0,  0,  0,  0]]], dtype=ht.int64, device=cpu:0, split=0)
    """
    # early out if pad width is 0
    if pad_width == 0:
        return array

    if not isinstance(array, DNDarray):
        raise TypeError(f"expected array to be a ht.DNDarray, but was {type(array)}")

    if not isinstance(mode, str):
        raise TypeError(f"expected mode to be a string, but was {type(mode)}")

    # ------------------------------------------------------------------------------------------
    # Translate `pad_width` into torch's flat pad specification `pad`:
    # (before_last, after_last, before_second_to_last, after_second_to_last, ...)
    # ------------------------------------------------------------------------------------------
    # shortcut int for all dimensions
    if isinstance(pad_width, int):
        pad = (pad_width,) * 2 * len(array.shape)

    elif not isinstance(pad_width, (tuple, list)):
        raise TypeError(
            f"expected pad_width to be an integer or a sequence (tuple or list), but was {type(pad_width)}"
        )

    # shortcut one sequence within a sequence for all dimensions - ((before,after), ) = pad_width
    elif len(pad_width) == 1:
        if isinstance(pad_width[0], int):
            pad = (pad_width[0],) * 2 * len(array.shape)
        elif not isinstance(pad_width[0], (tuple, list)):
            raise TypeError(
                f"For shortcut option '1 sequence for all dimensions', expected element within pad_width to be a tuple or list, but was {type(pad_width[0])}"
            )
        elif len(pad_width[0]) == 2:
            pad = pad_width[0] * len(array.shape)
        else:
            raise ValueError(
                f"Pad_width {pad_width} invalid.\n Apart from shortcut options (--> documentation), "
                "each sequence within pad_width must contain 2 elements."
            )

    # shortcut - one sequence for all dimensions - (before,after) = pad_width
    elif len(pad_width) == 2 and isinstance(pad_width[0], int) and isinstance(pad_width[1], int):
        pad_width = tuple(pad_width)
        pad = pad_width * len(array.shape)

    # no shortcut - padding of various dimensions
    else:
        if any(not isinstance(pad_tuple, (tuple, list)) for pad_tuple in pad_width):
            raise TypeError(
                f"Invalid type for pad_width {pad_width}.\nApart from shortcut options (--> documentation),"
                "pad_width has to be a sequence of (2 elements) sequences (sequence=tuple or list)."
            )
        pad = ()
        # Transform numpy pad_width to torch pad (--> one tuple containing all padding spans).
        # Prepending reverses the axis order, as torch lists the last dimension first.
        for pad_tuple in pad_width:
            if isinstance(pad_tuple, list):
                pad_tuple = tuple(pad_tuple)
            pad = pad_tuple + pad

        if len(pad) % 2 != 0:
            raise ValueError(
                f"Pad_width {pad_width} invalid.\n Apart from shortcut options (--> documentation), "
                "each sequence within pad_width must contain 2 elements."
            )

        if len(pad) // 2 > len(array.shape):
            raise ValueError(
                f"Not enough dimensions to pad.\n"
                f"Padding a {len(array.shape)}-dimensional tensor for {len(pad) // 2}"
                f" dimensions is not possible."
            )

    # value_tuple = all padding values stored in 1 tuple (same ordering as `pad`)
    if isinstance(constant_values, (tuple, list)):
        value_tuple = ()
        # sequences for each dimension defined within one sequence
        if isinstance(constant_values[0], (tuple, list)):
            # one sequence for all dimensions - values = ((before, after),)
            if len(constant_values) == 1:
                value_tuple = constant_values[0] * (len(pad) // 2)
            else:
                for value_pair in constant_values:
                    if isinstance(value_pair, tuple):
                        pass
                    elif isinstance(value_pair, list):
                        value_pair = tuple(value_pair)
                    else:
                        raise TypeError(
                            f"Value pair {value_pair} within values invalid. Expected all elements within values to be sequences(list/tuple),"
                            f"but one was: {type(value_pair)}"
                        )
                    value_tuple = value_pair + value_tuple

            if len(value_tuple) % 2 != 0:
                raise ValueError(
                    f"Expected values to contain an even amount of elements, but got {len(value_tuple)}"
                )

        # One sequence for all dimensions - values = (before, after)
        elif len(constant_values) == 2:
            value_tuple = constant_values * (len(pad) // 2)

    rank_array = len(array.shape)
    amount_pad_dim = len(pad) // 2
    # indices of the (trailing) dimensions that receive padding
    pad_dim = [rank_array - i for i in range(1, amount_pad_dim + 1)]

    array_torch = array.larray

    if array.split is not None:
        counts = array.comm.counts_displs_shape(array.gshape, array.split)[0]
        amount_of_processes = len(counts)

    # calculate gshape for output tensor
    output_shape_list = list(array.gshape)

    for i in range(0, len(pad), 2):
        output_shape_list[-((i // 2) + 1)] += sum(pad[i : i + 2])

    output_shape = tuple(output_shape_list)

    # -------------------------------------------------------------------------------------------------------------------
    # CASE 1: Padding in non split dimension or no distribution at all
    # ------------------------------------------------------------------------------------------------------------------
    # no data
    if 0 in array.lshape:
        adapted_lshape_list = [
            0 if i == array.split else output_shape[i] for i in range(len(output_shape))
        ]
        adapted_lshape = tuple(adapted_lshape_list)
        padded_torch_tensor = torch.empty(
            adapted_lshape, dtype=array._DNDarray__array.dtype, device=array.device.torch_device
        )
    else:
        # `amount_of_processes` is only defined for split arrays; the short-circuiting `or`
        # guarantees it is not evaluated when `array.split is None`.
        if array.split is None or array.split not in pad_dim or amount_of_processes == 1:
            # values = scalar
            if isinstance(constant_values, (int, float)):
                padded_torch_tensor = torch.nn.functional.pad(
                    array_torch, pad, mode, constant_values
                )
            # values = sequence with one value for all dimensions
            elif len(constant_values) == 1 and isinstance(constant_values[0], (int, float)):
                padded_torch_tensor = torch.nn.functional.pad(
                    array_torch, pad, mode, constant_values[0]
                )
            else:
                # individual fill values: pad one edge at a time with its own value
                padded_torch_tensor = array_torch
                for i in range(len(value_tuple) - 1, -1, -1):
                    pad_list = [0] * 2 * rank_array
                    pad_list[i] = pad[i]
                    pad_tuple = tuple(pad_list)
                    padded_torch_tensor = torch.nn.functional.pad(
                        padded_torch_tensor, pad_tuple, mode, value_tuple[i]
                    )
        else:
            # ------------------------------------------------------------------------------------------------------------------
            # CASE 2: padding in split dimension and function runs on more than 1 process
            #
            # Pad only first/last tensor portion on node (i.e. only beginning/end in split dimension)
            # --> "Calculate" pad tuple for the corresponding tensor portion/ the two indices which have to be set to zero
            #      in different paddings depending on the dimension
            #       Calculate the index of the first element in tuple that has to change/set to zero in
            #       some dimensions (the following is the second)
            # ------------------------------------------------------------------------------------------------------------------
            pad_beginning_list = list(pad)
            pad_end_list = list(pad)
            pad_middle_list = list(pad)

            # calculate the corresponding pad tuples
            first_idx_set_zero = 2 * (rank_array - array.split - 1)

            pad_end_list[first_idx_set_zero] = 0
            pad_beginning_list[first_idx_set_zero + 1] = 0
            pad_middle_list[first_idx_set_zero : first_idx_set_zero + 2] = [0, 0]

            pad_beginning = tuple(pad_beginning_list)
            pad_end = tuple(pad_end_list)
            pad_middle = tuple(pad_middle_list)

            if amount_of_processes >= array.shape[array.split]:
                last_ps_with_data = array.shape[array.split] - 1
            else:
                last_ps_with_data = amount_of_processes - 1

            rank = array.comm.rank

            # A single process holding all the data along the split axis is first AND last:
            # it has to apply both the leading and the trailing padding. Without this branch
            # the trailing padding along the split axis was silently dropped.
            if rank == 0 and rank == last_ps_with_data:
                pad_tuple_curr_rank = tuple(pad)
            # first process - pad beginning
            elif rank == 0:
                pad_tuple_curr_rank = pad_beginning
            # last process - pad end
            elif rank == last_ps_with_data:
                pad_tuple_curr_rank = pad_end
            # pad middle
            else:
                pad_tuple_curr_rank = pad_middle

            if isinstance(constant_values, (int, float)):
                padded_torch_tensor = torch.nn.functional.pad(
                    array_torch, pad_tuple_curr_rank, mode, constant_values
                )
            elif len(constant_values) == 1 and isinstance(constant_values[0], (int, float)):
                padded_torch_tensor = torch.nn.functional.pad(
                    array_torch, pad_tuple_curr_rank, mode, constant_values[0]
                )
            else:
                padded_torch_tensor = array_torch
                for i in range(len(value_tuple) - 1, -1, -1):
                    pad_list = [0] * 2 * rank_array
                    pad_list[i] = pad_tuple_curr_rank[i]
                    pad_tuple = tuple(pad_list)
                    padded_torch_tensor = torch.nn.functional.pad(
                        padded_torch_tensor, pad_tuple, mode, value_tuple[i]
                    )

    padded_tensor = factories.array(
        padded_torch_tensor,
        dtype=array.dtype,
        is_split=array.split,
        device=array.device,
        comm=array.comm,
    )

    padded_tensor.balance_()

    return padded_tensor
def ravel(a: DNDarray) -> DNDarray:
    """
    Collapse `a` into one dimension, avoiding a copy of the local data where possible.

    Parameters
    ----------
    a : DNDarray
        array to collapse

    Notes
    -----
    The local data is flattened and wrapped without copying (``copy=False``) only when `a` is
    not distributed or is split along axis 0; the returned DNDarray may then be unbalanced.
    For any other split axis data must be communicated among processes, and `ravel` falls
    back to `flatten`.

    See Also
    --------
    :func:`flatten`

    Examples
    --------
    >>> a = ht.ones((2, 3), split=0)
    >>> b = ht.ravel(a)
    >>> a[0, 0] = 4
    >>> b
    DNDarray([4., 1., 1., 1., 1., 1.], dtype=ht.float32, device=cpu:0, split=0)
    """
    sanitation.sanitize_in(a)

    # split along any axis other than the first: redistribution necessary
    if a.split is not None and a.split != 0:
        return flatten(a)

    # here `a.split` is either None or 0, and is passed on unchanged
    flat_local = torch.flatten(a._DNDarray__array)
    return factories.array(
        flat_local,
        dtype=a.dtype,
        copy=False,
        is_split=a.split,
        device=a.device,
        comm=a.comm,
    )
def redistribute(
    arr: DNDarray, lshape_map: torch.Tensor = None, target_map: torch.Tensor = None
) -> DNDarray:
    """
    Out-of-place redistribution of a :class:`DNDarray` *along the split axis*.

    A copy of `arr` is created and :func:`DNDarray.redistribute_` is applied to that copy, so
    `arr` itself is left as it is. Non-split dimensions are not modified.
    This is an abstraction and extension of the balance function.

    Parameters
    ----------
    arr: DNDarray
        DNDarray to redistribute
    lshape_map : torch.Tensor, optional
        The current lshape of processes.
        Units are ``[rank, lshape]``.
    target_map : torch.Tensor, optional
        The desired distribution across the processes.
        Units are ``[rank, target lshape]``.
        Note: the only important parts of the target map are the values along the split axis,
        values which are not along this axis are there to mimic the shape of the ``lshape_map``.

    Examples
    --------
    >>> st = ht.ones((50, 81, 67), split=2)
    >>> target_map = torch.zeros((st.comm.size, 3), dtype=torch.int64)
    >>> target_map[0, 2] = 67
    >>> print(target_map)
    [0/2] tensor([[ 0,  0, 67],
    [0/2]         [ 0,  0,  0],
    [0/2]         [ 0,  0,  0]])
    [1/2] tensor([[ 0,  0, 67],
    [1/2]         [ 0,  0,  0],
    [1/2]         [ 0,  0,  0]])
    [2/2] tensor([[ 0,  0, 67],
    [2/2]         [ 0,  0,  0],
    [2/2]         [ 0,  0,  0]])
    >>> print(st.lshape)
    [0/2] (50, 81, 23)
    [1/2] (50, 81, 22)
    [2/2] (50, 81, 22)
    >>> st = ht.redistribute(st, target_map=target_map)
    >>> print(st.lshape)
    [0/2] (50, 81, 67)
    [1/2] (50, 81, 0)
    [2/2] (50, 81, 0)
    """
    redistributed = arr.copy()
    # the in-place variant does the actual work, applied to the copy only
    redistributed.redistribute_(lshape_map=lshape_map, target_map=target_map)
    return redistributed
# Make the out-of-place `redistribute` available as a DNDarray method and let the method
# share the function's docstring.
DNDarray.redistribute = lambda arr, lshape_map=None, target_map=None: redistribute(
    arr, lshape_map, target_map
)
DNDarray.redistribute.__doc__ = redistribute.__doc__
def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
    """
    Creates a new `DNDarray` by repeating elements of array `a`. The output has
    the same shape as `a`, except along the given axis. If axis is None, this
    function returns a flattened `DNDarray`.

    Parameters
    ----------
    a : array_like (i.e. int, float, or tuple/ list/ np.ndarray/ ht.DNDarray of ints/floats)
        Array containing the elements to be repeated.
    repeats : int, or 1-dimensional/ DNDarray/ np.ndarray/ list/ tuple of ints
        The number of repetitions for each element, indicates broadcast if int or array_like of 1 element.
        In this case, the given value is broadcasted to fit the shape of the given axis.
        Otherwise, its length must be the same as a in the specified axis. To put it differently, the
        amount of repetitions has to be determined for each element in the corresponding dimension
        (or in all dimensions if axis is None).
    axis: int, optional
        The axis along which to repeat values. By default, use the flattened input array and return a flat output
        array. Negative values are rejected with a ``ValueError``.

    Raises
    ------
    TypeError
        If `a`, `axis` or `repeats` (or an element / the dtype of `repeats`) has an unsupported type.
    ValueError
        If `axis` is out of range, `repeats` is empty or not 1-dimensional, or the number of
        entries in `repeats` does not match the number of elements of `a` (along `axis`).

    Warns
    -----
    UserWarning
        Whenever `a` or `repeats` has to be re-split (copied with a different split axis) before
        the local repetition can be carried out.

    Notes
    -----
    If any global dimension of `a` has size 0, `a` itself is returned.
    The result is balanced (``balance_()``) before it is returned.

    Examples
    --------
    >>> ht.repeat(3, 4)
    DNDarray([3, 3, 3, 3])
    >>> x = ht.array([[1, 2], [3, 4]])
    >>> ht.repeat(x, 2)
    DNDarray([1, 1, 2, 2, 3, 3, 4, 4])
    >>> x = ht.array([[1, 2], [3, 4]])
    >>> ht.repeat(x, [0, 1, 2, 0])
    DNDarray([2, 3, 3])
    >>> ht.repeat(x, [1, 2], axis=0)
    DNDarray([[1, 2],
              [3, 4],
              [3, 4]])
    """
    # sanitation `a`
    if not isinstance(a, DNDarray):
        if isinstance(a, (int, float)):
            # scalars are wrapped into a 1-element array
            a = factories.array([a])
        elif isinstance(a, (tuple, list, np.ndarray)):
            a = factories.array(a)
        else:
            raise TypeError(
                f"`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {type(a)}"
            )

    # sanitation `axis`
    if axis is not None and not isinstance(axis, int):
        raise TypeError(f"`axis` must be an integer or None, currently: {type(axis)}")

    if axis is not None and (axis >= len(a.shape) or axis < 0):
        raise ValueError(
            f"Invalid input for `axis`. Value has to be either None or between 0 and {len(a.shape) - 1}, not {axis}."
        )

    # sanitation `repeats`
    if not isinstance(repeats, (int, list, tuple, np.ndarray, DNDarray)):
        raise TypeError(
            f"`repeats` must be an integer, list, tuple, np.ndarray or ht.DNDarray of integers, currently: {type(repeats)}"
        )

    # no broadcast implied
    if not isinstance(repeats, int):
        # make sure everything inside `repeats` is int; after this block `repeats` is an
        # int64 DNDarray regardless of the container type that was passed in
        if isinstance(repeats, DNDarray):
            if repeats.dtype == types.int64:
                pass
            elif types.can_cast(repeats.dtype, types.int64):
                repeats = factories.array(
                    repeats,
                    dtype=types.int64,
                    is_split=repeats.split,
                    device=repeats.device,
                    comm=repeats.comm,
                )
            else:
                raise TypeError(
                    f"Invalid dtype for ht.DNDarray `repeats`. Has to be integer, but was {repeats.dtype}"
                )
        elif isinstance(repeats, np.ndarray):
            if not types.can_cast(repeats.dtype.type, types.int64):
                raise TypeError(
                    f"Invalid dtype for np.ndarray `repeats`. Has to be integer, but was {repeats.dtype.type}"
                )
            repeats = factories.array(
                repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
            )
        elif not all(isinstance(r, int) for r in repeats):
            raise TypeError(
                "Invalid type within `repeats`. All components of `repeats` must be integers."
            )
        else:
            repeats = factories.array(
                repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
            )

        # check `repeats` is not empty
        if repeats.gnumel == 0:
            raise ValueError("Invalid input for `repeats`. `repeats` must contain data.")

        # check `repeats` is 1-dimensional
        if len(repeats.shape) != 1:
            raise ValueError(
                f"Invalid input for `repeats`. `repeats` must be a 1d-object or integer, but was {len(repeats.shape)}-dimensional."
            )

    # start of algorithm

    # nothing to repeat: the input object itself is handed back
    if 0 in a.gshape:
        return a

    # Broadcast (via int or 1-element DNDarray)
    if isinstance(repeats, int) or repeats.gnumel == 1:
        if axis is None and a.split is not None and a.split != 0:
            warnings.warn(
                f"If axis is None, `a` has to be split along axis 0 (not {a.split}) if distributed.\n`a` will be copied with new split axis 0."
            )
            a = resplit(a, 0)
        if isinstance(repeats, int):
            repeated_array_torch = torch.repeat_interleave(a._DNDarray__array, repeats, axis)
        else:
            if repeats.split is not None:
                warnings.warn(
                    "For broadcast via array_like repeats, `repeats` must not be "
                    "distributed (along axis {}).\n`repeats` will be "
                    "copied with new split axis None.".format(repeats.split)
                )
                repeats = resplit(repeats, None)
            repeated_array_torch = torch.repeat_interleave(
                a._DNDarray__array, repeats._DNDarray__array, axis
            )
    # No broadcast
    else:
        # check if the data chunks of `repeats` and/or `a` have to be (re)distributed before call of torch function.

        # UNDISTRIBUTED CASE (a not distributed)
        if a.split is None:
            if repeats.split is not None:
                warnings.warn(
                    f"If `a` is undistributed, `repeats` also has to be undistributed (not split along axis {repeats.split}).\n`repeats` will be copied "
                    "with new split axis None."
                )
                repeats = resplit(repeats, None)

            # Check correct input
            if axis is None:
                # check matching shapes (repetition defined for every element)
                if a.gnumel != repeats.gnumel:
                    raise ValueError(
                        f"Invalid input. Sizes of flattened `a` ({a.gnumel}) and `repeats` ({repeats.gnumel}) are not same. "
                        "Please revise your definition specifying repetitions for all elements of the DNDarray `a` "
                        "or replace repeats with a single scalar."
                    )
            # axis is not None
            elif a.lshape[axis] != repeats.lnumel:
                raise ValueError(
                    f"Invalid input. Amount of elements of `repeats` ({repeats.lnumel}) and of `a` in the specified axis ({a.lshape[axis]}) "
                    "are not the same. Please revise your definition specifying repetitions for all elements "
                    "of the DNDarray `a` or replace `repeats` with a single scalar"
                )
        # DISTRIBUTED CASE (a distributed)
        elif axis is None:
            if a.gnumel != repeats.gnumel:
                raise ValueError(
                    f"Invalid input. Sizes of flattened `a` ({a.gnumel}) and `repeats` ({repeats.gnumel}) are not same. "
                    "Please revise your definition specifying repetitions for all elements of the DNDarray `a` "
                    "or replace `repeats` with a single scalar."
                )

            if a.split != 0:
                warnings.warn(
                    f"If `axis` is None, `a` has to be split along axis 0 (not {a.split}) if distributed.\n`a` will be copied "
                    "with new split axis 0."
                )
                a = resplit(a, 0)

            # give `repeats` the shape of `a` and split it along axis 0 as well, then flatten
            # the local part; presumably so that the local chunks of both line up element by
            # element -- TODO confirm
            repeats = repeats.reshape(a.gshape)
            if repeats.split != 0:
                warnings.warn(
                    f"If `axis` is None, `repeats` has to be split along axis 0 (not {repeats.split}) if distributed.\n`repeats` will be copied "
                    "with new split axis 0."
                )
                repeats = resplit(repeats, 0)
            flatten_repeats_t = torch.flatten(repeats._DNDarray__array)
            repeats = factories.array(
                flatten_repeats_t,
                is_split=repeats.split,
                device=repeats.device,
                comm=repeats.comm,
            )

        # axis is not None
        elif a.split == axis:
            if repeats.split != 0:
                warnings.warn(
                    f"If `axis` equals `a.split`, `repeats` has to be split along axis 0 (not {repeats.split}) if distributed.\n`repeats` will be copied "
                    "with new split axis 0"
                )
                repeats = resplit(repeats, 0)

        # a.split != axis
        else:
            if repeats.split is not None:
                warnings.warn(
                    f"If `axis` != `a.split`, `repeast` must not be distributed (along axis {repeats.split}).\n`repeats` will be copied with new "
                    "split axis None."
                )
                repeats = resplit(repeats, None)

            if a.lshape[axis] != repeats.lnumel:
                raise ValueError(
                    f"Invalid input. Amount of elements of `repeats` ({repeats.lnumel}) and of `a` in the specified axis ({a.lshape[axis]}) "
                    "are not the same. Please revise your definition specifying repetitions for all elements "
                    "of the DNDarray `a` or replace `repeats` with a single scalar"
                )

        # purely local operation on the process-local tensors of `a` and `repeats`
        repeated_array_torch = torch.repeat_interleave(
            a._DNDarray__array, repeats._DNDarray__array, axis
        )

    # wrap the local results; `is_split` tells the factory they are chunks along `a.split`
    repeated_array = factories.array(
        repeated_array_torch, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
    )
    repeated_array.balance_()

    return repeated_array
def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDarray:
    """
    Returns an array with the same data and number of elements as `a`, but with the specified shape.

    Parameters
    ----------
    a : DNDarray
        The input array
    shape : Union[int, Tuple[int,...]]
        Shape of the new array. Must be compatible with the original shape. If an integer, then the result will be a 1-D array of that length.
        One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
    new_split : int, optional
        The distribution axis of the reshaped array. If `new_split` is not provided, the reshaped array will have:
        -  the same split axis as the input array, if the original dimensionality is unchanged;
        -  split axis 0, if the number of dimensions is modified by reshaping.
    **kwargs
        Extra keyword arguments. Only ``new_split`` is read.

    Raises
    ------
    TypeError
        If `a` is not a DNDarray, or if `shape` cannot be interpreted as a shape.
    ValueError
        If the number of elements in the new shape is inconsistent with the input data.

    Notes
    -----
    `reshape()` might require significant communication among processes. Communication is minimized if the input array is distributed along axis 0, i.e. `a.split == 0`.

    See Also
    --------
    :func:`ravel`

    Examples
    --------
    >>> a = ht.zeros((3, 4))
    >>> ht.reshape(a, (4, 3))
    DNDarray([[0., 0., 0.],
              [0., 0., 0.],
              [0., 0., 0.],
              [0., 0., 0.]], dtype=ht.float32, device=cpu:0, split=None)
    >>> a = ht.linspace(0, 14, 8, split=0)
    >>> ht.reshape(a, (2, 4))
    (1/2) tensor([[0., 2., 4., 6.]])
    (2/2) tensor([[ 8., 10., 12., 14.]])
    # 3-dim array, distributed along axis 1
    >>> a = ht.random.rand(2, 3, 4, split=1)
    >>> a
    DNDarray([[[0.5525, 0.5434, 0.9477, 0.9503],
               [0.4165, 0.3924, 0.3310, 0.3935],
               [0.1008, 0.1750, 0.9030, 0.8579]],
              [[0.0680, 0.4944, 0.4114, 0.6669],
               [0.6423, 0.2625, 0.5413, 0.2225],
               [0.0197, 0.5079, 0.4739, 0.4387]]], dtype=ht.float32, device=cpu:0, split=1)
    >>> a.reshape(-1, 3)  # reshape to 2-dim array: split axis will be set to 0
    DNDarray([[0.5525, 0.5434, 0.9477],
              [0.9503, 0.4165, 0.3924],
              [0.3310, 0.3935, 0.1008],
              [0.1750, 0.9030, 0.8579],
              [0.0680, 0.4944, 0.4114],
              [0.6669, 0.6423, 0.2625],
              [0.5413, 0.2225, 0.0197],
              [0.5079, 0.4739, 0.4387]], dtype=ht.float32, device=cpu:0, split=0)
    >>> a.reshape(2, 3, 2, 2, new_split=1)  # reshape to 4-dim array, specify distribution axis
    DNDarray([[[[0.5525, 0.5434],
                [0.9477, 0.9503]],
               [[0.4165, 0.3924],
                [0.3310, 0.3935]],
               [[0.1008, 0.1750],
                [0.9030, 0.8579]]],
              [[[0.0680, 0.4944],
                [0.4114, 0.6669]],
               [[0.6423, 0.2625],
                [0.5413, 0.2225]],
               [[0.0197, 0.5079],
                [0.4739, 0.4387]]]], dtype=ht.float32, device=cpu:0, split=1)
    """
    if not isinstance(a, DNDarray):
        raise TypeError(f"'a' must be a DNDarray, currently {type(a)}")

    # use numpys _ShapeLike but expand to handle torch and heat Tensors
    # (zero strides: the proxy has the global shape of `a` but is backed by a single element,
    # so it only serves to let numpy validate/normalise the requested shape)
    np_proxy = np.lib.stride_tricks.as_strided(np.ones(1), a.gshape, [0] * a.ndim, writeable=False)
    try:
        np_proxy.reshape(shape)  # numpy defines their own _ShapeLike
    except TypeError as e:  # handle Tensors and DNDarrays
        try:  # make shape a np.ndarray
            if len(shape) == 1:
                shape = shape[0]
            if hasattr(shape, "cpu"):  # move to cpu
                shape = shape.cpu()
            if hasattr(shape, "detach"):  # torch.Tensors have to detach before numpy call
                shape = shape.detach()
            if hasattr(shape, "numpy"):  # for DNDarrays
                shape = shape.numpy()
            else:  # Try to coerce everything else.
                shape = np.asarray(shape)
        except Exception:
            # any failure while coercing is reported with the message of the original TypeError
            raise TypeError(e)
    # a ValueError raised here (element count mismatch) is deliberately not caught
    shape = np_proxy.reshape(shape).shape  # sanitized shape according to numpy

    # tdtype, tdevice = a.dtype.torch_type(), a.device.torch_device
    tdevice = a.device.torch_device

    # always operate along split axis 0
    # NOTE(review): `resplit_` is called on the input `a` itself, so this appears to change
    # the split axis of the caller's array as a side effect -- confirm this is intended.
    orig_split = a.split
    if a.split is not None:
        a.resplit_(axis=0)

    local_a = a.larray

    # check new_split parameter
    new_split = kwargs.get("new_split")
    if new_split is None:
        if orig_split is not None and len(shape) != a.ndim:
            # dimensionality reduced or expanded
            # set output split axis to 0
            new_split = 0
        else:
            new_split = orig_split
    new_split = stride_tricks.sanitize_axis(shape, new_split)

    if not a.is_distributed():
        if a.comm.size > 1 and new_split is not None:
            # keep local slice only
            _, _, local_slice = a.comm.chunk(shape, new_split)
            _ = local_a.reshape(shape)
            local_a = _[local_slice].contiguous()
            del _
        else:
            local_a = local_a.reshape(shape)
        return DNDarray(
            local_a,
            gshape=shape,
            dtype=a.dtype,
            split=new_split,
            device=a.device,
            comm=a.comm,
            balanced=True,
        )

    lshape_map = a.lshape_map
    rank = a.comm.rank
    size = a.comm.size

    # flatten dimensions from split axis on, e.g. if split = 1, (3,4,5) -> (3,20)
    local_elements_off_split = torch.prod(lshape_map[:, a.split :], dim=1)
    first_local_shape = tuple(lshape_map[rank, : a.split].tolist()) + (
        local_elements_off_split[rank].item(),
    )
    local_a = local_a.reshape(first_local_shape)
    first_global_shape = tuple(lshape_map[rank, : a.split].tolist()) + (
        local_elements_off_split.sum().item(),
    )
    reshape_first_pass = DNDarray(
        local_a,
        gshape=first_global_shape,
        dtype=a.dtype,
        split=a.split,
        device=a.device,
        comm=a.comm,
        balanced=a.balanced,
    )
    # `new_lshape_map` is a slice (view) of `lshape_map`, so the assignment below also
    # writes into `lshape_map`
    new_lshape_map = lshape_map[:, : a.split + 1]
    new_lshape_map[:, a.split] = local_elements_off_split
    # NOTE(review): outside a class body this name is not mangled, so the line sets an
    # attribute literally called `__lshape_map`; confirm whether DNDarray's private
    # lshape-map cache was the intended target.
    reshape_first_pass.__lshape_map = new_lshape_map

    # redistribute if necessary. All splits but a.split = 0 have been ruled out
    lshape_map = reshape_first_pass.lshape_map
    current_numel = torch.prod(lshape_map, dim=1)
    # calculate target number of elements on each rank
    target_numel = torch.zeros((size, len(shape)), dtype=torch.int64, device=tdevice)
    for i in range(size):
        _, local_shape, _ = a.comm.chunk(shape, a.split, rank=i)
        target_numel[i] = torch.tensor(local_shape)
        if i == rank:
            # remember the local target shape of this process for the final reshape
            second_local_shape = local_shape
    target_numel = torch.prod(target_numel, dim=1)
    if (target_numel == current_numel).all():
        # every process already holds exactly the elements of its target chunk
        local_a = local_a.reshape(second_local_shape)
    else:
        # redistribution is necessary before reshaping
        target_map = lshape_map.clone()
        target_map[:, a.split] = target_numel
        reshape_first_pass.redistribute_(target_map=target_map)
        local_a = reshape_first_pass.larray.reshape(second_local_shape)
    reshape_final = DNDarray(
        local_a,
        gshape=shape,
        dtype=a.dtype,
        split=0,
        device=a.device,
        comm=a.comm,
        balanced=None,
    )
    # the work above was done with split axis 0; move to the requested output split
    reshape_final.resplit_(axis=new_split)

    return reshape_final
# Make `reshape` available as a DNDarray method (all positional and keyword arguments are
# forwarded) and let the method share the function's docstring.
DNDarray.reshape = lambda self, *shape, **kwargs: reshape(self, *shape, **kwargs)
DNDarray.reshape.__doc__ = reshape.__doc__
def roll(
    x: DNDarray, shift: Union[int, Tuple[int]], axis: Optional[Union[int, Tuple[int]]] = None
) -> DNDarray:
    """
    Rolls array elements along a specified axis. Array elements that roll beyond the last position are re-introduced at the first position.
    Array elements that roll beyond the first position are re-introduced at the last position.

    Parameters
    ----------
    x : DNDarray
        input array
    shift : Union[int, Tuple[int, ...]]
        number of places by which the elements are shifted. If 'shift' is a tuple, then 'axis' must be a tuple of the same size, and each of
        the given axes is shifted by the corresponding element in 'shift'. If 'shift' is an `int` and 'axis' a `tuple`, then the same shift
        is used for all specified axes.
    axis : Optional[Union[int, Tuple[int, ...]]]
        axis (or axes) along which elements to shift. If 'axis' is `None`, the array is flattened, shifted, and then restored to its original shape.
        Default: `None`.

    Raises
    ------
    TypeError
        If 'shift' or 'axis' is not of type `int`, `list` or `tuple`.
    ValueError
        If 'shift' and 'axis' are tuples with different sizes.

    Notes
    -----
    Rolling along the split axis of a distributed array exchanges the data between processes
    with one non-blocking send and one non-blocking receive per local slice along that axis.
    All other axes are rolled locally with ``torch.roll``.

    Examples
    --------
    >>> a = ht.arange(20).reshape((4, 5))
    >>> a
    DNDarray([[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, 1)
    DNDarray([[19,  0,  1,  2,  3],
              [ 4,  5,  6,  7,  8],
              [ 9, 10, 11, 12, 13],
              [14, 15, 16, 17, 18]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, -1, 0)
    DNDarray([[ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19],
              [ 0,  1,  2,  3,  4]], dtype=ht.int32, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(x)

    if isinstance(axis, list):
        axis = tuple(axis)
    axis = stride_tricks.sanitize_axis(x.shape, axis)

    if axis is None:
        # roll the flattened array along its only axis, then restore shape and split
        return roll(x.flatten(), shift, 0).reshape(x.shape, new_split=x.split)

    # inputs are ints
    if isinstance(shift, int):
        if isinstance(axis, int):
            if not x.is_distributed():
                return DNDarray(
                    torch.roll(x.larray, shift, axis),
                    gshape=x.shape,
                    dtype=x.dtype,
                    split=x.split,
                    device=x.device,
                    comm=x.comm,
                    balanced=x.balanced,
                )
            # x is distributed
            if axis == x.split:
                # roll along split axis
                size = x.comm.Get_size()
                rank = x.comm.Get_rank()

                # local elements along axis:
                lshape_map = x.create_lshape_map(force_check=False)[:, x.split]
                cumsum_map = torch.cumsum(lshape_map, dim=0)  # cumulate along axis
                indices = torch.arange(size, device=x.device.torch_device)
                index_map = torch.repeat_interleave(indices, lshape_map)  # index -> process

                # compute index positions
                # (global indices, along the split axis, of the slices held by this process)
                index_old = torch.arange(lshape_map[rank], device=x.device.torch_device)
                if rank > 0:
                    index_old += cumsum_map[rank - 1]

                # global index each local slice moves to / global index whose slice ends up here
                send_index = (index_old + shift) % x.gshape[x.split]
                recv_index = (index_old - shift) % x.gshape[x.split]

                # exchange arrays
                # `recv_splits` are slices of `recv`, so the receives fill `recv` in place
                recv = torch.empty_like(x.larray)
                recv_splits = torch.split(recv, 1, dim=x.split)
                recv_requests = [None for i in range(x.lshape[x.split])]

                # receives are posted first; the third argument (the global destination
                # index) is the same value the matching Isend passes, presumably as the
                # message tag -- TODO confirm against Communication.Irecv/Isend
                for i in range(x.lshape[x.split]):
                    recv_requests[i] = x.comm.Irecv(
                        recv_splits[i], index_map[recv_index[i]], index_old[i]
                    )

                send_splits = torch.split(x.larray, 1, dim=x.split)
                send_requests = [None for i in range(x.lshape[x.split])]

                for i in range(x.lshape[x.split]):
                    send_requests[i] = x.comm.Isend(
                        send_splits[i], index_map[send_index[i]], send_index[i]
                    )

                # wait for all transfers to complete before `recv` is handed out
                for i in range(x.lshape[x.split]):
                    recv_requests[i].Wait()
                for i in range(x.lshape[x.split]):
                    send_requests[i].Wait()

                return DNDarray(recv, x.gshape, x.dtype, x.split, x.device, x.comm, x.balanced)
            # distributed, but `axis` is not the split axis: falls through to the local
            # torch.roll at the end of the function

        else:  # pytorch does not support int / sequence combo at the time, make shift a list instead
            try:
                axis = sanitation.sanitize_sequence(axis)
            except TypeError:
                raise TypeError(f"axis must be a int, list or a tuple, got {type(axis)}")

            shift = [shift] * len(axis)

            if not x.is_distributed():
                return DNDarray(
                    torch.roll(x.larray, shift, axis),
                    gshape=x.shape,
                    dtype=x.dtype,
                    split=x.split,
                    device=x.device,
                    comm=x.comm,
                    balanced=x.balanced,
                )
            # x is distributed
            # handled by the sequence/sequence branch below via recursion
            return roll(x, shift, axis)

    else:  # input must be tuples now
        try:
            shift = sanitation.sanitize_sequence(shift)
        except TypeError:
            raise TypeError(f"shift must be an integer, list or a tuple, got {type(shift)}")

        try:
            axis = sanitation.sanitize_sequence(axis)
        except TypeError:
            raise TypeError(f"axis must be an integer, list or a tuple, got {type(axis)}")

        if len(shift) != len(axis):
            raise ValueError(
                f"shift and axis length must be the same, got {len(shift)} and {len(axis)}"
            )

        for i in range(len(shift)):
            if not isinstance(shift[i], int):
                raise TypeError(f"Element {i} in shift is not an integer, got {type(shift[i])}")
            if not isinstance(axis[i], int):
                raise TypeError(f"Element {i} in axis is not an integer, got {type(axis[i])}")

        if not x.is_distributed():
            return DNDarray(
                torch.roll(x.larray, shift, axis),
                gshape=x.shape,
                dtype=x.dtype,
                split=x.split,
                device=x.device,
                comm=x.comm,
                balanced=x.balanced,
            )
        # x is distributed
        if x.split in axis:
            # remove split axis elements
            # (every entry naming the split axis, positively or negatively indexed, is taken
            # out of `axis`/`shift` and its shift is accumulated into `shift_split`)
            shift_split = 0
            for y in (x.split, x.split - x.ndim):
                idx = [i for i in range(len(axis)) if axis[i] == y]
                for i in idx:
                    shift_split += shift[i]
                # delete back to front so that the remaining indices in `idx` stay valid
                for i in reversed(idx):
                    axis.remove(y)
                    del shift[i]

            # compute new array along split axis
            x = roll(x, shift_split, x.split)
            if len(axis) == 0:
                return x

    # use PyTorch for all other axes
    rolled = torch.roll(x.larray, shift, axis)
    return DNDarray(
        rolled,
        gshape=x.shape,
        dtype=x.dtype,
        split=x.split,
        device=x.device,
        comm=x.comm,
        balanced=x.balanced,
    )
def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarray:
    """
    Rotate an array by 90 degrees in the plane spanned by `axes`, turning from the first
    towards the second axis.

    Parameters
    ----------
    m : DNDarray
        Array of two or more dimensions.
    k : integer
        Number of times the array is rotated by 90 degrees.
    axes: (2,) Sequence[int, int]
        The array is rotated in the plane defined by the axes. Axes must be different.

    Raises
    ------
    ValueError
        If `len(axes)!=2`.
    ValueError
        If the axes are the same.
    ValueError
        If axes are out of range.
    TypeError
        If `m` is not a `DNDarray`, or if `m` is split and `k` cannot be cast to an integer.

    Notes
    -----
    - ``rot90(m, k=1, axes=(1,0))`` is the reverse of ``rot90(m, k=1, axes=(0,1))``.\n
    - ``rot90(m, k=1, axes=(1,0))`` is equivalent to ``rot90(m, k=-1, axes=(0,1))``.

    May change the split axis on distributed tensors.

    Examples
    --------
    >>> m = ht.array([[1, 2], [3, 4]], dtype=ht.int)
    >>> m
    DNDarray([[1, 2],
              [3, 4]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.rot90(m)
    DNDarray([[2, 4],
              [1, 3]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.rot90(m, 2)
    DNDarray([[4, 3],
              [2, 1]], dtype=ht.int32, device=cpu:0, split=None)
    >>> m = ht.arange(8).reshape((2, 2, 2))
    >>> ht.rot90(m, 1, (1, 2))
    DNDarray([[[1, 3],
               [0, 2]],
    <BLANKLINE>
              [[5, 7],
               [4, 6]]], dtype=ht.int32, device=cpu:0, split=None)
    """
    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")
    if not isinstance(m, DNDarray):
        raise TypeError(f"expected m to be a ht.DNDarray, but was {type(m)}")

    first, second = axes
    ndim = m.ndim

    # identical axes, or the same axis spelled once positively and once negatively
    if first == second or np.absolute(first - second) == ndim:
        raise ValueError("Axes must be different.")
    if any(ax >= ndim or ax < -ndim for ax in axes):
        raise ValueError(f"Axes={axes} out of range for array of ndim={m.ndim}.")

    # without a split axis the rotation is delegated to torch entirely
    if m.split is None:
        rotated = torch.rot90(m.larray, k, axes)
        return factories.array(rotated, dtype=m.dtype, device=m.device, comm=m.comm)

    try:
        k = int(k)
    except (TypeError, ValueError):
        raise TypeError("Unknown type, must be castable to integer")
    k = k % 4

    if k == 0:
        return m.copy()
    if k == 2:
        # half turn: reverse both axes of the rotation plane
        return flip(flip(m, first), second)

    # permutation that exchanges the two axes of the rotation plane
    permutation = list(range(ndim))
    swapped_first, swapped_second = permutation[second], permutation[first]
    permutation[first] = swapped_first
    permutation[second] = swapped_second

    if k == 1:
        return linalg.transpose(flip(m, second), permutation)
    # k == 3
    return flip(linalg.transpose(m, permutation), second)
# Expose rot90 as a DNDarray method. The method's keyword is named `axis`; it is forwarded
# positionally as the `axes` argument of rot90. The function's docstring is shared.
DNDarray.rot90 = lambda self, k=1, axis=(0, 1): rot90(self, k, axis)
DNDarray.rot90.__doc__ = rot90.__doc__
def shape(a: DNDarray) -> Tuple[int, ...]:
    """
    Return the global shape of a (potentially distributed) `DNDarray` as a tuple.

    Parameters
    ----------
    a : DNDarray
        The input `DNDarray`.

    Raises
    ------
    TypeError
        If `a` is not a `DNDarray`.
    """
    if isinstance(a, DNDarray):
        return a.gshape
    raise TypeError(f"Expected a to be a DNDarray but was {type(a)}")
def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None):
    """
    Sorts the elements of `a` along the given dimension (by default in ascending order) by their value.
    The sorting is not stable which means that equal elements in the result may have a different ordering
    than in the original array.
    Sorting where `axis==a.split` needs a lot of communication between the processes of MPI.
    Returns a tuple `(values, indices)` with the sorted local results and the indices of the elements in
    the original data. If `out` is given, the sorted values are written to `out.larray` and only the
    indices are returned.

    Parameters
    ----------
    a : DNDarray
        Input array to be sorted.
    axis : int, optional
        The dimension to sort along. Negative values count from the last axis.
        Default is the last axis.
    descending : bool, optional
        If set to `True`, values are sorted in descending order.
    out : DNDarray, optional
        A location in which to store the results. If provided, it must have a broadcastable shape. If not
        provided or set to `None`, a fresh array is allocated.

    Raises
    ------
    ValueError
        If `axis` is not consistent with the available dimensions.

    Examples
    --------
    >>> x = ht.array([[4, 1], [2, 3]], split=0)
    >>> x.shape
    (1, 2)
    (1, 2)
    >>> y = ht.sort(x, axis=0)
    >>> y
    (array([[2, 1]], array([[1, 0]]))
    (array([[4, 3]], array([[0, 1]]))
    >>> ht.sort(x, descending=True)
    (array([[4, 1]], array([[0, 1]]))
    (array([[3, 2]], array([[1, 0]]))
    """
    # TODO: Find a better way to ignore specific warnings. This one seems to be related to numpy trying to sort torch tensors. Maybe changing to torch.sort would help?
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message=r".*__array_wrap__ must accept context and return_scalar arguments.*",
    )

    # Keep the sanitized axis: without the assignment a negative axis (including the default -1)
    # never compares equal to a.split, and a sort along the split axis would only be done locally.
    axis = stride_tricks.sanitize_axis(a.shape, axis)

    if not a.is_distributed() or axis != a.split:
        # sorting is not affected by split -> we can just sort along the axis
        final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending)

    else:
        # sorting is affected by split, processes need to communicate results
        # transpose so we can work along the 0 axis
        transposed = a.larray.transpose(axis, 0)
        local_sorted, local_indices = torch.sort(transposed, dim=0, descending=descending)

        size = a.comm.Get_size()
        rank = a.comm.Get_rank()
        counts, disp, _ = a.comm.counts_displs_shape(a.gshape, axis=axis)

        # shift local sort indices by this process' displacement to obtain global indices
        actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank]

        length = local_sorted.size()[0]

        # Separate the sorted tensor into size + 1 equal length partitions
        partitions = [x * length // (size + 1) for x in range(1, size + 1)]
        local_pivots = (
            local_sorted[partitions]
            if counts[rank]
            else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype)
        )

        # Only processes with elements should share their pivots
        gather_counts = [int(x > 0) * size for x in counts]
        gather_displs = (0,) + tuple(np.cumsum(gather_counts[:-1]))

        pivot_dim = list(transposed.size())
        pivot_dim[0] = size * sum(x > 0 for x in counts)

        # share the local pivots with root process
        pivot_buffer = torch.empty(
            pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device
        )
        a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0)

        pivot_dim[0] = size - 1
        global_pivots = torch.empty(
            pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device
        )

        # root process creates new pivots and shares them with other processes
        if rank == 0:
            sorted_pivots, _ = torch.sort(pivot_buffer, descending=descending, dim=0)
            length = sorted_pivots.size()[0]
            global_partitions = [x * length // size for x in range(1, size)]
            global_pivots = sorted_pivots[global_partitions]

        a.comm.Bcast(global_pivots, root=0)

        lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64)
        last = torch.zeros_like(local_sorted, dtype=torch.int64)
        comp_op = torch.gt if descending else torch.lt
        # Iterate over all pivots and store which pivot is the first greater than the elements value
        for idx, p in enumerate(global_pivots):
            lt = comp_op(local_sorted, p).int()
            if idx > 0:
                lt_partitions[idx] = lt - last
            else:
                lt_partitions[idx] = lt
            last = lt
        # everything not captured by any pivot belongs to the last partition
        lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last

        # Matrix holding information how many values will be sent where
        local_partitions = torch.sum(lt_partitions, dim=1)

        partition_matrix = torch.empty_like(local_partitions)
        a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM)

        # Matrix that holds information which value will be shipped where
        index_matrix = torch.empty_like(local_sorted, dtype=torch.int64)

        # Matrix holding information which process get how many values from where
        shape = (size,) + transposed.size()[1:]
        send_matrix = torch.zeros(shape, dtype=partition_matrix.dtype)
        recv_matrix = torch.zeros(shape, dtype=partition_matrix.dtype)

        for i, x in enumerate(lt_partitions):
            index_matrix[x > 0] = i
            send_matrix[i] += torch.sum(x, dim=0)

        a.comm.Alltoall(send_matrix, recv_matrix)

        scounts = local_partitions
        rcounts = recv_matrix

        shape = (partition_matrix[rank].max(),) + transposed.size()[1:]
        first_result = torch.empty(shape, dtype=local_sorted.dtype)
        first_indices = torch.empty_like(first_result)

        # Iterate through one layer and send values with alltoallv
        for idx in np.ndindex(local_sorted.shape[1:]):
            idx_slice: tuple(slice, ...) = (slice(None),) + tuple(
                slice(ind, ind + 1) for ind in idx
            )

            send_count = scounts[idx_slice].reshape(-1).tolist()
            send_disp = [0] + list(np.cumsum(send_count[:-1]))
            s_val = local_sorted[idx_slice].clone()
            s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype)

            recv_count = rcounts[idx_slice].reshape(-1).tolist()
            recv_disp = [0] + list(np.cumsum(recv_count[:-1]))
            rcv_length = rcounts[idx_slice].sum().item()
            r_val = torch.empty((rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype)
            r_ind = torch.empty_like(r_val)

            # values and their global indices travel in two matching exchanges
            a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp))
            a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp))
            first_result[idx_slice][:rcv_length] = r_val
            first_indices[idx_slice][:rcv_length] = r_ind

        # The process might not have the correct number of values therefore the tensors need to be rebalanced
        send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64)
        target_cumsum = np.cumsum(counts)
        for idx in np.ndindex(local_sorted.shape[1:]):
            idx_slice: tuple(slice, ...) = (slice(None),) + tuple(
                slice(ind, ind + 1) for ind in idx
            )
            current_counts = partition_matrix[idx_slice].reshape(-1).tolist()
            current_cumsum = list(np.cumsum(current_counts))
            for proc in range(size):
                if current_cumsum[proc] > target_cumsum[proc]:
                    # process has to many values which will be sent to higher ranks
                    first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i])
                    last = next(
                        i
                        for i in range(size + 1)
                        if i == size or current_cumsum[proc] < target_cumsum[i]
                    )
                    sent = 0
                    for i, x in enumerate(counts[first:last]):
                        # Each following process gets as many elements as it needs
                        amount = int(x - send_vec[idx][:, first + i].sum())
                        send_vec[idx][proc][first + i] = amount
                        current_counts[first + i] += amount
                        sent = send_vec[idx][proc][: first + i + 1].sum().item()
                    if last < size:
                        # Send all left over values to the highest last process
                        amount = partition_matrix[proc][idx]
                        send_vec[idx][proc][last] = int(amount - sent)
                        current_counts[last] += int(amount - sent)
                elif current_cumsum[proc] < target_cumsum[proc]:
                    # process needs values from higher rank
                    first = (
                        0
                        if proc == 0
                        else next(
                            i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x
                        )
                    )
                    last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x)
                    for i, x in enumerate(partition_matrix[idx_slice][first:last]):
                        # Taking as many elements as possible from each following process
                        send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum())
                        current_counts[first + i] = 0
                    # Taking just enough elements from the last element to fill the current processes tensor
                    send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1])
                    current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1])
                else:
                    # process doesn't need more values
                    send_vec[idx][proc][proc] = (
                        partition_matrix[proc][idx] - send_vec[idx][proc].sum()
                    )
                current_counts[proc] = counts[proc]
                current_cumsum = list(np.cumsum(current_counts))

        # Iterate through one layer again to create the final balanced local tensors
        second_result = torch.empty_like(local_sorted)
        second_indices = torch.empty_like(second_result)
        for idx in np.ndindex(local_sorted.shape[1:]):
            idx_slice: tuple(slice, ...) = (slice(None),) + tuple(
                slice(ind, ind + 1) for ind in idx
            )

            send_count = send_vec[idx][rank]
            send_disp = [0] + list(np.cumsum(send_count[:-1]))

            recv_count = send_vec[idx][:, rank]
            recv_disp = [0] + list(np.cumsum(recv_count[:-1]))

            end = partition_matrix[rank][idx]
            s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0)
            s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val)

            r_val = torch.empty((counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype)
            r_ind = torch.empty_like(r_val)

            a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp))
            a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp))

            second_result[idx_slice] = r_val
            second_indices[idx_slice] = r_ind

        second_result, tmp_indices = second_result.sort(dim=0, descending=descending)
        # undo the initial transpose so the result has the layout of the input
        final_result = second_result.transpose(0, axis)
        final_indices = torch.empty_like(second_indices)
        # Update the indices in case the ordering changed during the last sort
        for idx in np.ndindex(tmp_indices.shape):
            val = tmp_indices[idx]
            final_indices[idx] = second_indices[val.item()][idx[1:]]
        final_indices = final_indices.transpose(0, axis)
    return_indices = factories.array(
        final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm
    )
    if out is not None:
        # values go into the caller-provided array; only the indices are returned
        out.larray = final_result
        return return_indices
    else:
        tensor = factories.array(
            final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
        )
        return tensor, return_indices
def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DNDarray, ...]:
    """
    Split a DNDarray into multiple sub-DNDarrays.
    Returns a list of sub-DNDarrays as copies of parts of `x`.

    Parameters
    ----------
    x : DNDarray
        DNDArray to be divided into sub-DNDarrays.
    indices_or_sections : int or 1-dimensional array_like (i.e. undistributed DNDarray, list or tuple)
        If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays
        along axis. If such a split is not possible, an error is raised.
        If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along
        axis the array is split. For example, `indices_or_sections = [2, 3]` would, for `axis = 0`,
        result in

        - `x[:2]`
        - `x[2:3]`
        - `x[3:]`

        If an index exceeds the dimension of the array along axis, an empty sub-array is returned
        correspondingly.
    axis : int, optional
        The axis along which to split, default is 0. Must be non-negative.

    Raises
    ------
    TypeError
        If `axis` is not an integer or `indices_or_sections` is neither an integer, list, tuple nor
        DNDarray.
    ValueError
        If `axis` is out of range, if `indices_or_sections` is an integer that is not positive or does
        not divide the axis equally, or if `indices_or_sections` is array_like but not 1-dimensional.

    Warnings
    --------
    Though it is possible to distribute `x`, this function has nothing to do with the split
    parameter of a DNDarray.

    See Also
    --------
    :func:`dsplit`
    :func:`hsplit`
    :func:`vsplit`

    Examples
    --------
    >>> x = ht.arange(12).reshape((4, 3))
    >>> ht.split(x, 2)
    [ DNDarray([[0, 1, 2],
                [3, 4, 5]]),
      DNDarray([[ 6,  7,  8],
                [ 9, 10, 11]])]
    >>> ht.split(x, [2, 3, 5])
    [ DNDarray([[0, 1, 2],
                [3, 4, 5]]),
      DNDarray([[6, 7, 8]]
      DNDarray([[ 9, 10, 11]]),
      DNDarray([])]
    >>> ht.split(x, [1, 2], 1)
    [DNDarray([[0],
               [3],
               [6],
               [9]]),
     DNDarray([[ 1],
               [ 4],
               [ 7],
               [10]],
     DNDarray([[ 2],
               [ 5],
               [ 8],
               [11]])]
    """
    # sanitize x
    sanitation.sanitize_in(x)

    # sanitize axis
    if not isinstance(axis, int):
        raise TypeError(f"Expected `axis` to be an integer, but was {type(axis)}")
    if axis < 0 or axis > len(x.gshape) - 1:
        raise ValueError(
            f"Invalid input for `axis`. Valid range is between 0 and {len(x.gshape) - 1}, but was {axis}"
        )

    # sanitize indices_or_sections
    if isinstance(indices_or_sections, int):
        # a non-positive section count would otherwise surface as ZeroDivisionError or as an
        # opaque torch error about a negative split size
        if indices_or_sections <= 0:
            raise ValueError(
                f"Number of sections must be larger than 0, but was {indices_or_sections}"
            )
        if x.gshape[axis] % indices_or_sections != 0:
            raise ValueError(
                f"DNDarray with shape {x.gshape} can't be divided equally into {indices_or_sections} chunks along axis {axis}"
            )
        # np to torch mapping - calculate size of resulting data chunks
        indices_or_sections_t = x.gshape[axis] // indices_or_sections

    elif isinstance(indices_or_sections, (list, tuple, DNDarray)):
        if isinstance(indices_or_sections, (list, tuple)):
            indices_or_sections = factories.array(indices_or_sections)
        if len(indices_or_sections.gshape) != 1:
            raise ValueError(
                f"Expected indices_or_sections to be 1-dimensional, but was {len(indices_or_sections.gshape)}-dimensional instead."
            )
    else:
        raise TypeError(
            f"Expected `indices_or_sections` to be array_like (DNDarray, list or tuple), but was {type(indices_or_sections)}"
        )

    # start of actual algorithm

    if x.split == axis and x.is_distributed():
        if isinstance(indices_or_sections, int):
            # CASE 1 number of processes == indices_or_selections -> split already done due to distribution
            if x.comm.size == indices_or_sections:
                new_lshape = list(x.lshape)
                new_lshape[axis] = 0
                # every process contributes its own local data to "its" chunk and an empty
                # tensor to all others
                sub_arrays_t = [
                    torch.empty(new_lshape) if i != x.comm.rank else x._DNDarray__array
                    for i in range(indices_or_sections)
                ]

            # CASE 2 number of processes != indices_or_selections -> reorder (and split) chunks correctly
            else:
                # no data
                if x.lshape[axis] == 0:
                    sub_arrays_t = [torch.empty(x.lshape) for i in range(indices_or_sections)]
                else:
                    offset, local_shape, slices = x.comm.chunk(x.gshape, axis)
                    idx_frst_chunk_affctd = offset // indices_or_sections_t
                    left_data_chunk = indices_or_sections_t - (offset % indices_or_sections_t)
                    left_data_process = x.lshape[axis]

                    # number of local elements that go into each of the output chunks
                    new_indices = torch.zeros(indices_or_sections, dtype=int)

                    if left_data_chunk >= left_data_process:
                        new_indices[idx_frst_chunk_affctd] = left_data_process
                    else:
                        new_indices[idx_frst_chunk_affctd] = left_data_chunk
                        left_data_process -= left_data_chunk
                        idx_frst_chunk_affctd += 1

                        # calculate chunks which can be filled completely
                        left_chunks_to_fill = left_data_process // indices_or_sections_t
                        new_indices[
                            idx_frst_chunk_affctd : (left_chunks_to_fill + idx_frst_chunk_affctd)
                        ] = indices_or_sections_t

                        # assign residual to following process
                        # When the local data ends exactly on the last chunk boundary there is no
                        # residual and the "following" index lies one past the end of new_indices,
                        # so the write must be skipped (the entry is already 0 otherwise).
                        residual = left_data_process % indices_or_sections_t
                        if residual:
                            new_indices[left_chunks_to_fill + idx_frst_chunk_affctd] = residual

                    sub_arrays_t = torch.split(x._DNDarray__array, new_indices.tolist(), axis)
        # indices or sections == DNDarray
        else:
            if indices_or_sections.split is not None:
                warnings.warn(
                    f"`indices_or_sections` might not be distributed (along axis {indices_or_sections.split}) "
                    "if `x` is not distributed.\n`indices_or_sections` will be copied with new split axis None."
                )
                indices_or_sections = resplit(indices_or_sections, None)

            offset, local_shape, slices = x.comm.chunk(x.gshape, axis)
            slice_axis = slices[axis]

            # reduce information to the (chunk) relevant
            # i.e. clamp all split positions into the index range held by this process
            indices_or_sections_t = indexing.where(
                indices_or_sections <= slice_axis.start, slice_axis.start, indices_or_sections
            )

            indices_or_sections_t = indexing.where(
                indices_or_sections_t >= slice_axis.stop, slice_axis.stop, indices_or_sections_t
            )

            # np to torch mapping

            # 2. add first and last value to DNDarray
            # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes
            indices_or_sections_t = arithmetics.diff(
                indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop
            )
            indices_or_sections_t = factories.array(
                indices_or_sections_t,
                dtype=types.int64,
                is_split=indices_or_sections_t.split,
                comm=indices_or_sections_t.comm,
                device=indices_or_sections_t.device,
            )

            # 4. transform the result into a list (torch requirement)
            indices_or_sections_t = indices_or_sections_t.tolist()

            sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis)
    else:
        if isinstance(indices_or_sections, int):
            sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis)
        else:
            if indices_or_sections.split is not None:
                warnings.warn(
                    f"`indices_or_sections` might not be distributed (along axis {indices_or_sections.split}) if `x` is not distributed.\n"
                    "`indices_or_sections` will be copied with new split axis None."
                )
                indices_or_sections = resplit(indices_or_sections, None)

            # np to torch mapping

            # 1. replace all values out of range with gshape[axis] to generate size 0
            indices_or_sections_t = indexing.where(
                indices_or_sections <= x.gshape[axis], indices_or_sections, x.gshape[axis]
            )

            # 2. add first and last value to DNDarray
            # 3. calculate the 1-st discrete difference therefore corresponding chunk sizes
            indices_or_sections_t = arithmetics.diff(
                indices_or_sections_t, prepend=0, append=x.gshape[axis]
            )
            indices_or_sections_t = factories.array(
                indices_or_sections_t,
                dtype=types.int64,
                is_split=indices_or_sections_t.split,
                comm=indices_or_sections_t.comm,
                device=indices_or_sections_t.device,
            )

            # 4. transform the result into a list (torch requirement)
            indices_or_sections_t = indices_or_sections_t.tolist()

            sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis)

    # wrap the local torch chunks as DNDarrays that keep the split axis of x
    sub_arrays_ht = [
        factories.array(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm)
        for sub_DNDarray in sub_arrays_t
    ]

    for sub_DNDarray in sub_arrays_ht:
        sub_DNDarray.balance_()

    return sub_arrays_ht
def squeeze(x: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray:
    """
    Remove single-element entries from the shape of a `DNDarray`.
    Returns the input array, but with all or a subset (indicated by `axis`) of the dimensions of
    length 1 removed. The input array itself is not modified.
    Split semantics: see Notes below.

    Parameters
    ----------
    x : DNDarray
        Input data.
    axis : None or int or Tuple[int,...], optional
        Selects a subset of the single-element entries in the shape.
        If axis is `None`, all single-element entries will be removed from the shape.

    Raises
    ------
    ValueError
        If an axis is selected with shape entry greater than one.

    Notes
    -----
    Split semantics: a distributed DNDarray will keep its original split dimension after "squeezing",
    which, depending on the squeeze axis, may result in a lower numerical `split` value (see Examples).
    If the split dimension itself is squeezed away, the result is not split.

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.random.randn(1, 3, 1, 5)
    >>> a.shape
    (1, 3, 1, 5)
    >>> ht.squeeze(a).shape
    (3, 5)
    >>> ht.squeeze(a, axis=0).shape
    (3, 1, 5)
    >>> ht.squeeze(a, axis=-2).shape
    (1, 3, 5)
    >>> ht.squeeze(a, axis=1).shape
    Traceback (most recent call last):
    ...
    ValueError: Dimension along axis 1 is not 1 for shape (1, 3, 1, 5)
    >>> x.shape
    (10, 1, 12, 13)
    >>> x.split
    2
    >>> x.squeeze().shape
    (10, 12, 13)
    >>> x.squeeze().split
    1
    """
    # Sanitize input
    sanitation.sanitize_in(x)
    # Sanitize axis
    axis = stride_tricks.sanitize_axis(x.shape, axis)

    if axis is None:
        # squeeze every dimension of length 1
        axis = tuple(i for i, dim in enumerate(x.shape) if dim == 1)
    else:
        # keep the axis as sanitized for the error message, work on a tuple internally
        requested_axis = axis
        if isinstance(axis, int):
            axis = (axis,)
        if not all(x.shape[dim] == 1 for dim in axis):
            raise ValueError(
                f"Dimension along axis {requested_axis} is not 1 for shape {x.shape}"
            )

    if x.split is not None and x.split in axis:
        # split dimension is about to disappear, set split to None
        # (out-of-place, so that the caller's array keeps its distribution)
        x = resplit(x, None)

    out_lshape = tuple(x.lshape[dim] for dim in range(x.ndim) if dim not in axis)
    out_gshape = tuple(x.gshape[dim] for dim in range(x.ndim) if dim not in axis)
    x_lsqueezed = x.larray.reshape(out_lshape)

    # Calculate new split axis according to squeezed shape
    if x.split is not None:
        split = x.split - len([dim for dim in axis if dim < x.split])
    else:
        split = None

    return DNDarray(
        x_lsqueezed,
        out_gshape,
        x.dtype,
        split=split,
        device=x.device,
        comm=x.comm,
        balanced=x.balanced,
    )
# Expose squeeze as a DNDarray method that forwards to the module-level function and
# shares its docstring.
DNDarray.squeeze: Callable[[DNDarray, Union[int, Tuple[int, ...]]], DNDarray] = (
    lambda self, axis=None: squeeze(self, axis)
)
DNDarray.squeeze.__doc__ = squeeze.__doc__
def stack(
    arrays: Sequence[DNDarray, ...], axis: int = 0, out: Optional[DNDarray] = None
) -> DNDarray:
    """
    Join a sequence of `DNDarray`s along a new axis.
    The `axis` parameter specifies the index of the new axis in the dimensions of the result.
    For example, if `axis=0`, the arrays will be stacked along the first dimension; if `axis=-1`,
    they will be stacked along the last dimension. See Notes below for split semantics.

    Parameters
    ----------
    arrays : Sequence[DNDarrays, ...]
        Each DNDarray must have the same shape, must be split along the same axis, and must be balanced.
    axis : int, optional
        The axis in the result array along which the input arrays are stacked.
    out : DNDarray, optional
        If provided, the destination to place the result. The shape and split axis must be correct,
        matching that of what stack would have returned if no out argument were specified (see Notes
        below).

    Raises
    ------
    TypeError
        If arrays in sequence are not `DNDarray`s, or if their `dtype` attribute does not match.
    ValueError
        If `arrays` contains less than 2 `DNDarray`s.
    ValueError
        If the `DNDarray`s are of different shapes, or if they are split along different axes
        (`split` attribute).
    RuntimeError
        If the `DNDarrays` reside on different devices.

    Notes
    -----
    Split semantics: :func:`stack` requires that all arrays in the sequence be split along the same
    dimension. After stacking, the data are still distributed along the original dimension, however a
    new dimension has been added at `axis`, therefore:

    - if :math:`axis <= split`, output will be distributed along :math:`split+1`
    - if :math:`axis > split`, output will be distributed along `split`

    See Also
    --------
    :func:`column_stack`
    :func:`concatenate`
    :func:`hstack`
    :func:`row_stack`
    :func:`vstack`

    Examples
    --------
    >>> a = ht.arange(20).reshape((4, 5))
    >>> b = ht.arange(20, 40).reshape((4, 5))
    >>> ht.stack((a, b), axis=0).larray
    tensor([[[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9],
             [10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]],
            [[20, 21, 22, 23, 24],
             [25, 26, 27, 28, 29],
             [30, 31, 32, 33, 34],
             [35, 36, 37, 38, 39]]])
    >>> # distributed DNDarrays, 3 processes, stack along last dimension
    >>> a = ht.arange(20, split=0).reshape(4, 5)
    >>> b = ht.arange(20, 40, split=0).reshape(4, 5)
    >>> ht.stack((a, b), axis=-1).larray
    [0/2] tensor([[[ 0, 20],
    [0/2]          [ 1, 21],
    [0/2]          [ 2, 22],
    [0/2]          [ 3, 23],
    [0/2]          [ 4, 24]],
    [0/2]         [[ 5, 25],
    [0/2]          [ 6, 26],
    [0/2]          [ 7, 27],
    [0/2]          [ 8, 28],
    [0/2]          [ 9, 29]]])
    [1/2] tensor([[[10, 30],
    [1/2]          [11, 31],
    [1/2]          [12, 32],
    [1/2]          [13, 33],
    [1/2]          [14, 34]]])
    [2/2] tensor([[[15, 35],
    [2/2]          [16, 36],
    [2/2]          [17, 37],
    [2/2]          [18, 38],
    [2/2]          [19, 39]]])
    """
    arrays = sanitation.sanitize_sequence(arrays)
    if len(arrays) < 2:
        raise ValueError("stack expects a sequence of at least 2 DNDarrays")

    target = arrays[0]
    try:
        arrays = sanitation.sanitize_distribution(
            *arrays, target=target
        )  # also checks target again
    except NotImplementedError as e:  # transform split axis error to ValueError
        raise ValueError(e) from e

    # extract torch tensors
    t_arrays = [array.larray for array in arrays]

    # output shape and split
    # the axis is sanitized against a shape with one extra dimension, i.e. the result's rank
    axis = stride_tricks.sanitize_axis(target.gshape + (len(arrays),), axis)
    stacked_shape = target.gshape[:axis] + (len(arrays),) + target.gshape[axis:]
    if target.split is not None:
        stacked_split = target.split + 1 if axis <= target.split else target.split
    else:
        stacked_split = None

    # stack locally
    try:
        t_stacked = torch.stack(t_arrays, dim=axis)
        result_dtype = types.canonical_heat_type(t_stacked.dtype)
    except Exception as e:
        # Shape/size mismatches are reported as ValueError. The message is taken via str() so
        # that an exception without (string) args cannot make the handler itself fail and
        # hide the original error.
        message = str(e)
        if "size" in message or "shape" in message:
            raise ValueError(e) from e
        raise

    # return stacked DNDarrays
    if out is not None:
        sanitation.sanitize_out(out, stacked_shape, stacked_split, target.device)
        out.larray = t_stacked.type(out.larray.dtype)
        return out

    stacked = DNDarray(
        t_stacked,
        gshape=stacked_shape,
        dtype=result_dtype,
        split=stacked_split,
        device=target.device,
        comm=target.comm,
        balanced=target.balanced,
    )
    return stacked
def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray:
    """
    Interchanges two axes of an array.

    Parameters
    ----------
    x : DNDarray
        Input array.
    axis1 : int
        First axis.
    axis2 : int
        Second axis.

    Raises
    ------
    TypeError
        If `axis1` or `axis2` cannot be used as an integer index.

    See Also
    --------
    :func:`~heat.core.linalg.basics.transpose`
        Permute the dimensions of an array.

    Examples
    --------
    >>> x = ht.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
    >>> ht.swapaxes(x, 0, 1)
    DNDarray([[[0, 1],
               [4, 5]],
              [[2, 3],
               [6, 7]]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.swapaxes(x, 0, 2)
    DNDarray([[[0, 4],
               [2, 6]],
              [[1, 5],
               [3, 7]]], dtype=ht.int64, device=cpu:0, split=None)
    """
    # identity permutation in which the two requested positions get exchanged
    permutation = list(range(x.ndim))
    try:
        entry_at_axis2 = permutation[axis2]
        entry_at_axis1 = permutation[axis1]
        permutation[axis1] = entry_at_axis2
        permutation[axis2] = entry_at_axis1
    except TypeError:
        raise TypeError(
            f"'axis1' and 'axis2' must be of type int, found {type(axis1)} and {type(axis2)}"
        )

    return linalg.transpose(x, permutation)
# Expose swapaxes as a DNDarray method that forwards to the module-level function and
# shares its docstring.
DNDarray.swapaxes = lambda self, axis1, axis2: swapaxes(self, axis1, axis2)
DNDarray.swapaxes.__doc__ = swapaxes.__doc__
def unique(
    a: DNDarray, sorted: bool = False, return_inverse: bool = False, axis: Optional[int] = None
) -> Tuple[DNDarray, DNDarray]:
    """
    Finds and returns the unique elements of a `DNDarray`.
    If return_inverse is `True`, the second tensor will hold the list of inverse indices
    If distributed, it is most efficient if `axis!=a.split`.

    Parameters
    ----------
    a : DNDarray
        Input array.
    sorted : bool, optional
        Whether the found elements should be sorted before returning as output.
        Warning: sorted is not working if `axis!=None and axis!=a.split`
    return_inverse : bool, optional
        Whether to also return the indices for where elements in the original input ended up in the returned
        unique list.
    axis : int, optional
        Axis along which unique elements should be found. Default to `None`, which will return a one dimensional list of
        unique values.

    Returns
    -------
    DNDarray or sequence of DNDarray
        The unique values alone, or the unique values together with the inverse indices when
        `return_inverse` is `True`. The non-distributed path hands back whatever `torch.unique`
        produced (a tuple when several outputs are requested); the distributed path builds a
        two-element list.

    Raises
    ------
    ValueError
        If `a` is distributed, `axis` is neither `None` nor `a.split`, and `sorted` is `True`.

    Examples
    --------
    >>> x = ht.array([[3, 2], [1, 3]])
    >>> ht.unique(x, sorted=True)
    array([1, 2, 3])
    >>> ht.unique(x, sorted=True, axis=0)
    array([[1, 3],
           [2, 3]])
    >>> ht.unique(x, sorted=True, axis=1)
    array([[2, 3],
           [3, 1]])
    """
    # Non-distributed input: a single local torch.unique call is sufficient.
    if not a.is_distributed():
        torch_output = torch.unique(
            a.larray, sorted=sorted, return_inverse=return_inverse, dim=axis
        )
        if isinstance(torch_output, tuple):
            heat_output = tuple(
                factories.array(
                    i,
                    dtype=types.canonical_heat_type(i.dtype),
                    split=None,
                    device=a.device,
                    comm=a.comm,
                )
                for i in torch_output
            )
        else:
            heat_output = factories.array(
                torch_output, dtype=a.dtype, split=None, device=a.device, comm=a.comm
            )
        return heat_output

    local_data = a.larray
    unique_axis = None
    inverse_indices = None

    if axis is not None:
        # transpose so we can work along the 0 axis
        local_data = local_data.transpose(0, axis)
        unique_axis = 0

    # Calculate the unique on the local values
    if a.lshape[a.split] == 0:
        # Passing an empty vector to torch throws exception
        if axis is None:
            res_shape = [0]
            inv_shape = list(a.gshape)
            inv_shape[a.split] = 0
        else:
            res_shape = list(local_data.shape)
            res_shape[0] = 0
            inv_shape = [0]
        # NOTE(review): these placeholders (and `inverse_buf` below) are created without an
        # explicit device, whereas `gres_buf` uses `a.device.torch_device` — confirm this is
        # intended for non-CPU devices.
        lres = torch.empty(res_shape, dtype=a.dtype.torch_type())
        inverse_pos = torch.empty(inv_shape, dtype=torch.int64)

    else:
        lres, inverse_pos = torch.unique(
            local_data, sorted=sorted, return_inverse=True, dim=unique_axis
        )

    # Share and gather the results with the other processes
    # `uniques_buf` ends up with one entry per process: that process' number of local uniques.
    uniques = torch.tensor([lres.shape[0]]).to(torch.int32)
    uniques_buf = torch.empty((a.comm.Get_size(),), dtype=torch.int32)
    a.comm.Allgather(uniques, uniques_buf)

    if axis is None or axis == a.split:
        is_split = None
        split = a.split

        output_dim = list(lres.shape)
        output_dim[0] = uniques_buf.sum().item()

        # Gather all unique vectors
        counts = list(uniques_buf.tolist())
        displs = list([0] + uniques_buf.cumsum(0).tolist()[:-1])
        gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device)
        a.comm.Allgatherv(lres, (gres_buf, counts, displs), recv_axis=0)

        if return_inverse:
            # Prepare some information to generated the inverse indices list
            # NOTE(review): the counts below describe an even chunking of the split axis
            # (remainder spread over the first `rem` processes) — presumably `a` is expected
            # to be balanced here; confirm against callers.
            avg_len = a.gshape[a.split] // a.comm.Get_size()
            rem = a.gshape[a.split] % a.comm.Get_size()

            # Share the local reverse indices with other processes
            counts = [avg_len] * a.comm.Get_size()
            add_vec = [1] * rem + [0] * (a.comm.Get_size() - rem)
            inverse_counts = [sum(x) for x in zip(counts, add_vec)]
            inverse_displs = [0] + list(np.cumsum(inverse_counts[:-1]))
            inverse_dim = list(inverse_pos.shape)
            inverse_dim[a.split] = a.gshape[a.split]
            inverse_buf = torch.empty(inverse_dim, dtype=inverse_pos.dtype)

            # Transpose data and buffer so we can use Allgatherv along axis=0 (axis=1 does not work properly yet)
            inverse_pos = inverse_pos.transpose(0, a.split)
            inverse_buf = inverse_buf.transpose(0, a.split)
            a.comm.Allgatherv(
                inverse_pos, (inverse_buf, inverse_counts, inverse_displs), recv_axis=0
            )
            inverse_buf = inverse_buf.transpose(0, a.split)

        # Run unique a second time
        gres = torch.unique(gres_buf, sorted=sorted, return_inverse=return_inverse, dim=unique_axis)
        if return_inverse:
            # Use the previously gathered information to generate global inverse_indices
            g_inverse = gres[1]
            gres = gres[0]
            if axis is None:
                # Calculate how many elements we have in each layer along the split axis
                elements_per_layer = 1
                for num, val in enumerate(a.gshape):
                    if not num == a.split:
                        elements_per_layer *= val

                # Create the displacements for the flattened inverse indices array
                # The trailing infinity is a sentinel so the last process' range never ends.
                local_elements = [displ * elements_per_layer for displ in inverse_displs][1:] + [
                    float("inf")
                ]

                # Flatten the inverse indices array every element can be updated to represent a global index
                transposed = inverse_buf.transpose(0, a.split)
                transposed_shape = transposed.shape
                flatten_inverse = transposed.flatten()

                # Update the index elements iteratively
                cur_displ = 0
                inverse_indices = [0] * len(flatten_inverse)
                for num in range(len(inverse_indices)):
                    if num >= local_elements[cur_displ]:
                        cur_displ += 1
                    # local inverse index -> position in the gathered buffer -> global unique index
                    index = flatten_inverse[num] + displs[cur_displ]
                    inverse_indices[num] = g_inverse[index].tolist()

                # Convert the flattened array back to the correct global shape of a
                inverse_indices = torch.tensor(inverse_indices).reshape(transposed_shape)
                inverse_indices = inverse_indices.transpose(0, a.split)

            else:
                inverse_indices = torch.zeros_like(inverse_buf)
                steps = displs + [None]

                # Algorithm that creates the correct list for the reverse_indices
                for i in range(len(steps) - 1):
                    begin = steps[i]
                    end = steps[i + 1]
                    for num, x in enumerate(inverse_buf[begin:end]):
                        inverse_indices[begin + num] = g_inverse[begin + x]

    else:
        # Tensor is already split and does not need to be redistributed afterward
        split = None
        is_split = a.split
        max_uniques, max_pos = uniques_buf.max(0)
        # find indices of vectors
        if a.comm.Get_rank() == max_pos.item():
            # Get indices of the unique vectors to share with all over processes
            indices = inverse_pos.reshape(-1).unique()
        else:
            indices = torch.empty((max_uniques.item(),), dtype=inverse_pos.dtype)

        a.comm.Bcast(indices, root=max_pos)

        gres = local_data[indices.tolist()]

        inverse_indices = indices
        # NOTE(review): this check only fires after the collectives above have completed.
        if sorted:
            raise ValueError(
                "Sorting with axis != split is not supported yet. "
                "See https://github.com/helmholtz-analytics/heat/issues/363"
            )

    if axis is not None:
        # transpose matrix back
        gres = gres.transpose(0, axis)

    split = split if a.split < len(gres.shape) else None
    result = factories.array(
        gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split
    )
    if split is not None:
        result.resplit_(a.split)

    return_value = result
    if return_inverse:
        inverse_indices = factories.array(
            inverse_indices, dtype=inverse_pos.dtype, device=a.device, comm=a.comm
        )
        # NOTE(review): a list is returned here, while the non-distributed path returns a tuple.
        return_value = [return_value, inverse_indices]

    return return_value
# Expose `unique` as a DNDarray method (same defaults as the function) and reuse its docstring.
DNDarray.unique: Callable[[DNDarray, bool, bool, int], Tuple[DNDarray, torch.tensor]] = (
    lambda self, sorted=False, return_inverse=False, axis=None: unique(
        self, sorted, return_inverse, axis
    )
)
DNDarray.unique.__doc__ = unique.__doc__
def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
    """
    Returns a DNDarray which contains all slices of size `size` in the axis `axis`.

    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

    Parameters
    ----------
    a : DNDarray
        array to unfold
    axis : int
        axis in which unfolding happens
    size : int
        the size of each slice that is unfolded, must be greater than 1
    step : int
        the step between each slice, must be at least 1

    Raises
    ------
    ValueError
        If `step < 1`, `size <= 1`, or `size` exceeds the length of `a` along `axis`.
    RuntimeError
        If `a` is split along `axis` and some process holds fewer than `size - 1` elements
        along that axis.

    Example:
    ```
    >>> x = ht.arange(1., 8)
    >>> x
    DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.unfold(x, 0, 2, 1)
    DNDarray([[1., 2.],
              [2., 3.],
              [3., 4.],
              [4., 5.],
              [5., 6.],
              [6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.unfold(x, 0, 2, 2)
    DNDarray([[1., 2.],
              [3., 4.],
              [5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
    ```

    Note
    ---------
    You have to make sure that every node has at least chunk size size-1 if the split axis of the array is the unfold axis.
    """
    if step < 1:
        raise ValueError("step must be >= 1.")
    if size <= 1:
        raise ValueError("size must be > 1.")
    axis = stride_tricks.sanitize_axis(a.shape, axis)
    if size > a.shape[axis]:
        raise ValueError(
            f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
        )

    comm = a.comm
    dev = a.device
    tdev = dev.torch_device

    if a.split is None or comm.size == 1 or a.split != axis:  # early out
        # the unfold axis is entirely local: unfold the local tensor directly
        return factories.array(
            a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
        )

    # comm.size > 1 and split axis == unfold axis
    # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
    # --> size of axis: ceil((sizedim-size+1) / step) = floor(sizedim-size) / step)) + 1
    # ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)

    if (size - 1 > a.lshape_map[:, axis]).any():
        raise RuntimeError("Chunk-size needs to be at least size - 1.")
    # windows starting on this process may reach into the next process' chunk
    a.get_halo(size - 1, prev=False)

    _, displs = a.counts_displs()
    displs = torch.tensor(displs, device=tdev)

    # min local index in unfold axis:
    # first global multiple of `step` at or after this process' offset, as a local index
    min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
    if min_index >= a.lshape[axis] or (
        comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
    ):
        # no window starts on this process: contribute an empty chunk.
        # Use the dtype of the local data so that all processes agree on the result dtype
        # (torch.zeros would otherwise fall back to the default floating point dtype).
        loc_unfold_shape = list(a.lshape)
        loc_unfold_shape[axis] = 0
        ret_larray = torch.zeros((*loc_unfold_shape, size), dtype=a.larray.dtype, device=tdev)
    else:  # unfold has local data
        ret_larray = a.array_with_halos[
            axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
        ].unfold(axis, size, step)

    return factories.array(ret_larray, is_split=axis, device=dev, comm=comm)
def vsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray]:
    """
    Split array into multiple sub-DNDNarrays along the 1st axis (vertically/row-wise).
    Returns a list of sub-DNDarrays as copies of parts of ``x``.

    Parameters
    ----------
    x : DNDarray
        DNDArray to be divided into sub-DNDarrays.
    indices_or_sections : Iterable
        If `indices_or_sections` is an integer, N, the DNDarray will be divided into N equal DNDarrays along the 1st axis.\n
        If such a split is not possible, an error is raised.\n
        If `indices_or_sections` is a 1-D DNDarray of sorted integers, the entries indicate where along the 1st axis the array is split.\n
        If an index exceeds the dimension of the array along the 1st axis, an empty sub-DNDarray is returned correspondingly.\n

    Raises
    ------
    ValueError
        If `indices_or_sections` is given as integer, but a split does not result in equal division.

    Notes
    -----
    Please refer to the split documentation. :func:`vsplit` is equivalent to split with `axis=0`,
    the array is always split along the first axis regardless of the array dimension.

    See Also
    --------
    :func:`split`
    :func:`dsplit`
    :func:`hsplit`

    Examples
    --------
    >>> x = ht.arange(24).reshape((4, 3, 2))
    >>> ht.vsplit(x, 2)
    [DNDarray([[[ 0,  1],
                [ 2,  3],
                [ 4,  5]],
               [[ 6,  7],
                [ 8,  9],
                [10, 11]]]),
     DNDarray([[[12, 13],
                [14, 15],
                [16, 17]],
               [[18, 19],
                [20, 21],
                [22, 23]]])]
    >>> ht.vsplit(x, [1, 3])
    [DNDarray([[[0, 1],
                [2, 3],
                [4, 5]]]),
     DNDarray([[[ 6,  7],
                [ 8,  9],
                [10, 11]],
               [[12, 13],
                [14, 15],
                [16, 17]]]),
     DNDarray([[[18, 19],
                [20, 21],
                [22, 23]]])]
    """
    # delegate to the general `split`, pinned to the first axis
    return split(x, indices_or_sections, 0)
def resplit(arr: DNDarray, axis: Optional[int] = None) -> DNDarray:
    """
    Out-of-place redistribution of the content of the `DNDarray`. Allows to "unsplit" (i.e. gather) all values from all
    nodes, as well as to define a new axis along which the array is split without changes to the values.

    Parameters
    ----------
    arr : DNDarray
        The array from which to resplit
    axis : int or None
        The new split axis, `None` denotes gathering, an int will set the new split axis

    Returns
    -------
    DNDarray
        A new array holding the same values as `arr`, split along `axis`. A copy is returned
        even when `axis` already equals `arr.split`.

    Warning
    ----------
    This operation might involve a significant communication overhead. Use it sparingly and preferably for
    small arrays.

    Examples
    --------
    >>> a = ht.zeros((4, 5), split=0)
    >>> a.lshape
    (0/2) (2, 5)
    (1/2) (2, 5)
    >>> b = resplit(a, None)
    >>> b.split
    None
    >>> b.lshape
    (0/2) (4, 5)
    (1/2) (4, 5)
    >>> a = ht.zeros((4, 5), split=0)
    >>> a.lshape
    (0/2) (2, 5)
    (1/2) (2, 5)
    >>> b = resplit(a, 1)
    >>> b.split
    1
    >>> b.lshape
    (0/2) (4, 3)
    (1/2) (4, 2)
    """
    # sanitize the axis to check whether it is in range
    axis = stride_tricks.sanitize_axis(arr.shape, axis)

    # early out for unchanged content
    if axis == arr.split:
        return arr.copy()
    if not arr.is_distributed():
        return factories.array(arr.larray, split=axis, device=arr.device, comm=arr.comm, copy=True)

    if axis is None:
        # gather: every process receives the complete array
        gathered = torch.empty(
            arr.shape, dtype=arr.dtype.torch_type(), device=arr.device.torch_device
        )
        counts, displs = arr.counts_displs()
        arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
        return factories.array(
            gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype
        )

    # split axis changes: allocate the target on the SAME communicator as the source, since
    # the exchange below runs on `arr.comm` and the tiling must describe the same processes
    new_arr = factories.empty(
        arr.gshape, split=axis, dtype=arr.dtype, device=arr.device, comm=arr.comm
    )
    arr_tiles = tiling.SplitTiles(arr)
    new_tiles = tiling.SplitTiles(new_arr)
    new_arr.larray = _axis2axisResplit(
        arr.larray, arr.split, arr_tiles, new_arr.larray, axis, new_tiles, arr.comm
    )

    return new_arr
# Expose `resplit` as a DNDarray method (gathering by default) and reuse its docstring.
DNDarray.resplit: Callable[[DNDarray, Optional[int]], DNDarray] = lambda self, axis=None: resplit(
    self, axis
)
DNDarray.resplit.__doc__ = resplit.__doc__


def _axis2axisResplit(
    source_larray: torch.Tensor,
    source_split: int,
    source_tiles: tiling.SplitTiles,
    target_larray: torch.Tensor,
    target_split: int,
    target_tiles: tiling.SplitTiles,
    comm: Communication,
) -> torch.Tensor:
    """
    Resplits the input array along a new axis and performs data exchange using MPI_Alltoallw,
    after [1]. Returns `target_larray` object with the data after the exchange.

    Parameters
    ----------
    source_larray : torch.Tensor
        The source array to be resplit.
    source_split : int
        The axis along which the source array is split.
    source_tiles : tiling.SplitTiles
        The tiling object containing the subarray parameters for the source array.
    target_larray : torch.Tensor
        The target array to store the resplit data.
    target_split : int
        The axis along which the target array is split.
    target_tiles : tiling.SplitTiles
        The tiling object containing the subarray parameters for the target array.
    comm : Communication
        The communication object for MPI communication.

    Returns
    -------
    torch.Tensor
        `target_larray`, i.e. the very tensor object that was passed in.

    References
    ----------
    [1] Dalcin, Mortensen, Keyes, "Fast parallel multidimensional FFT using advanced MPI", 2018.
    """
    # Create subarray types for original local shapes split along the new axis
    source_subarray_params = source_tiles.get_subarray_params(source_split, target_split)

    # Create subarray types for resplit local array along the old axis
    target_subarray_params = target_tiles.get_subarray_params(target_split, source_split)

    # one subarray-typed element per peer, all starting at displacement zero
    world_size = comm.Get_size()
    counts = [1] * world_size
    displs = [0] * world_size

    # Perform the data exchange using MPI_Alltoallw
    comm.Alltoallw(
        (source_larray, (counts.copy(), displs.copy()), source_subarray_params),
        (target_larray, (counts.copy(), displs.copy()), target_subarray_params),
    )
    return target_larray


# Method-style access to `_axis2axisResplit`. The lambda keeps its historical parameter order
# (`comm` right after `self`) but forwards the arguments in the order the function expects:
# `comm` is the LAST positional parameter of `_axis2axisResplit`, not the first.
DNDarray._axis2axisResplit = (
    lambda self, comm, source_larray, source_split, source_tiles, target_larray, target_split, target_tile: (
        _axis2axisResplit(
            source_larray,
            source_split,
            source_tiles,
            target_larray,
            target_split,
            target_tile,
            comm,
        )
    )
)
DNDarray._axis2axisResplit.__doc__ = _axis2axisResplit.__doc__
def row_stack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
    """
    Stack 1-D or 2-D `DNDarray`s as rows into a 2-D `DNDarray`.
    If the input arrays are 1-D, they will be stacked as rows. If they are 2-D,
    they will be concatenated along the first axis.

    Parameters
    ----------
    arrays : Sequence[DNDarrays, ...]
        Sequence of `DNDarray`s.

    Raises
    ------
    ValueError
        If arrays have more than 2 dimensions

    Notes
    -----
    All ``DNDarray``s in the sequence must have the same number of columns.
    All ``DNDarray``s must be split along the same axis!

    See Also
    --------
    :func:`column_stack`
    :func:`concatenate`
    :func:`hstack`
    :func:`stack`
    :func:`vstack`

    Examples
    --------
    >>> # 1-D tensors
    >>> a = ht.array([1, 2, 3])
    >>> b = ht.array([2, 3, 4])
    >>> ht.row_stack((a, b)).larray
    tensor([[1, 2, 3],
            [2, 3, 4]])
    >>> # 1-D and 2-D tensors
    >>> a = ht.array([1, 2, 3])
    >>> b = ht.array([[2, 3, 4], [5, 6, 7]])
    >>> c = ht.array([[7, 8, 9], [10, 11, 12]])
    >>> ht.row_stack((a, b, c)).larray
    tensor([[ 1,  2,  3],
            [ 2,  3,  4],
            [ 5,  6,  7],
            [ 7,  8,  9],
            [10, 11, 12]])
    >>> # distributed 1-D and 2-D DNDarrays, 3 processes
    >>> a = ht.arange(5, split=0)
    >>> b = ht.arange(5, 20, split=0).reshape((3, 5))
    >>> ht.row_stack((a, b)).larray
    [0/2] tensor([[0, 1, 2, 3, 4],
    [0/2]         [5, 6, 7, 8, 9]])
    [1/2] tensor([[10, 11, 12, 13, 14]])
    [2/2] tensor([[15, 16, 17, 18, 19]])
    """
    ndims = [arr.ndim for arr in arrays]

    # sanitation, arrays can be 1-d or 2-d, see sanitation module #468
    if any(nd > 2 for nd in ndims):
        raise ValueError("Arrays must be 1-D or 2-D")

    n_flat = ndims.count(1)
    if n_flat == len(ndims):
        # nothing but 1-D inputs: each one becomes a row of the result
        return stack(arrays, axis=0)

    if n_flat > 0:
        # mixed input: promote every 1-D array to a single-row 2-D array first
        arrays = list(arrays)
        for pos, nd in enumerate(ndims):
            if nd == 1:
                flat = arrays[pos]
                arrays[pos] = flat.reshape((1, flat.size))

    return concatenate(arrays, axis=0)
def vstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
    """
    Stack arrays in sequence vertically (row wise).
    This is equivalent to concatenation along the first axis.
    This function makes most sense for arrays with up to 3 dimensions. For
    instance, for pixel-data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The :func:`concatenate` function provides more general
    stacking operations.

    Parameters
    ----------
    arrays : Sequence[DNDarray,...]
        The arrays must have the same shape along all but the first axis.
        1-D arrays must have the same length.

    Notes
    -----
    The split axis will be switched to 1 in the case that both elements are 1D and split=0

    See Also
    --------
    :func:`concatenate`
    :func:`stack`
    :func:`hstack`
    :func:`column_stack`
    :func:`row_stack`

    Examples
    --------
    >>> a = ht.array([1, 2, 3])
    >>> b = ht.array([2, 3, 4])
    >>> ht.vstack((a, b)).larray
    [0/1] tensor([[1, 2, 3],
    [0/1]         [2, 3, 4]])
    [1/1] tensor([[1, 2, 3],
    [1/1]         [2, 3, 4]])
    >>> a = ht.array([1, 2, 3], split=0)
    >>> b = ht.array([2, 3, 4], split=0)
    >>> ht.vstack((a, b)).larray
    [0/1] tensor([[1, 2],
    [0/1]         [2, 3]])
    [1/1] tensor([[3],
    [1/1]         [4]])
    >>> a = ht.array([[1], [2], [3]], split=0)
    >>> b = ht.array([[2], [3], [4]], split=0)
    >>> ht.vstack((a, b)).larray
    [0] tensor([[1],
    [0]         [2],
    [0]         [3]])
    [1] tensor([[2],
    [1]         [3],
    [1]         [4]])
    """
    rows = []
    for member in arrays:
        if len(member.gshape) == 1:
            # 1-D input: prepend an axis, then resplit using the split of the 1-D original
            member = member.expand_dims(0).resplit_(member.split)
        rows.append(member)
    return concatenate(rows, axis=0)
def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
    """
    Construct a new DNDarray by repeating 'x' the number of times given by 'reps'.

    If 'reps' has length 'd', the result will have 'max(d, x.ndim)' dimensions:

    - if 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes.
      So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3)
      for 3-D replication (if this is not the desired behavior, promote 'x' to d-dimensions
      manually before calling this function);

    - if 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if
      'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2).

    Parameters
    ----------
    x : DNDarray
        Input
    reps : Sequence[ints,...]
        Repetitions

    Returns
    -------
    tiled : DNDarray
        Split semantics: if `x` is distributed, the tiled data will be distributed along the
        same dimension. Note that nominally `tiled.split != x.split` in the case where
        `len(reps) > x.ndim`.  See example below.

    Raises
    ------
    TypeError
        If `x` has a `shape` but no `larray` attribute (i.e. is array-like but not a DNDarray),
        or if `reps` cannot be interpreted as a sequence of ints.

    Examples
    --------
    >>> x = ht.arange(12).reshape((4, 3)).resplit_(0)
    >>> x
    DNDarray([[ 0,  1,  2],
              [ 3,  4,  5],
              [ 6,  7,  8],
              [ 9, 10, 11]], dtype=ht.int32, device=cpu:0, split=0)
    >>> reps = (1, 2, 2)
    >>> tiled = ht.tile(x, reps)
    >>> tiled
    DNDarray([[[ 0,  1,  2,  0,  1,  2],
               [ 3,  4,  5,  3,  4,  5],
               [ 6,  7,  8,  6,  7,  8],
               [ 9, 10, 11,  9, 10, 11],
               [ 0,  1,  2,  0,  1,  2],
               [ 3,  4,  5,  3,  4,  5],
               [ 6,  7,  8,  6,  7,  8],
               [ 9, 10, 11,  9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1)
    """
    # x can be DNDarray or scalar
    try:
        _ = x.larray
    except AttributeError:
        try:
            _ = x.shape
            # has a shape but no larray: array-like, yet not a DNDarray
            raise TypeError(f"Input can be a DNDarray or a scalar, is {type(x)}")
        except AttributeError:
            x = factories.array(x).reshape(1)

    x_proxy = x.__torch_proxy__()

    # torch-proof args/kwargs:
    # torch `reps`: int or sequence of ints; numpy `reps`: can be array-like
    try:
        _ = x_proxy.repeat(reps)
    except TypeError:
        # `reps` is array-like or contains non-int elements
        try:
            reps = resplit(reps, None).tolist()
        except AttributeError:
            try:
                reps = reps.tolist()
            except AttributeError:
                try:
                    _ = x_proxy.repeat(reps)
                except TypeError:
                    raise TypeError(
                        f"reps must be a sequence of ints, got {[type(i) for i in reps]}"
                    )
                except RuntimeError:
                    pass
    except RuntimeError:
        pass

    try:
        reps = list(reps)
    except TypeError:
        # scalar to list
        reps = [reps]

    # torch reps vs. numpy reps: dimensions
    if len(reps) != x.ndim:
        added_dims = abs(len(reps) - x.ndim)
        if len(reps) > x.ndim:
            new_shape = added_dims * (1,) + x.gshape
            new_split = None if x.split is None else x.split + added_dims
            x = x.reshape(new_shape, new_split=new_split)
        else:
            reps = added_dims * [1] + reps

    out_gshape = tuple(x_proxy.repeat(reps).shape)

    if not x.is_distributed() or reps[x.split] == 1:
        # no repeats along the split axis: local operation
        return DNDarray(
            x.larray.repeat(reps),
            out_gshape,
            dtype=x.dtype,
            split=x.split,
            device=x.device,
            comm=x.comm,
            balanced=x.balanced,
        )

    # repeats along the split axis, work along dim 0
    size = x.comm.Get_size()
    rank = x.comm.Get_rank()
    trans_axes = list(range(x.ndim))
    if x.split != 0:
        trans_axes[0], trans_axes[x.split] = x.split, 0
        reps[0], reps[x.split] = reps[x.split], reps[0]
        x = linalg.transpose(x, trans_axes)
        x_proxy = x.__torch_proxy__()
        out_gshape = tuple(x_proxy.repeat(reps).shape)

    local_x = x.larray

    # allocate tiled DNDarray, at first tiled along split axis only.
    # It must live on the same device as `x`: its local tensor is the receive target of the
    # exchange below and is wrapped with `device=x.device` at the end.
    split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)]
    split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape)
    tiled = factories.empty(
        split_tiled_shape, dtype=x.dtype, split=x.split, device=x.device, comm=x.comm
    )

    # collect slicing information from all processes.
    slices_map = []
    for array in [x, tiled]:
        counts, displs = array.counts_displs()
        t_slices_starts = torch.tensor(displs, device=local_x.device)
        t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device)
        slices_map.append([t_slices_starts, t_slices_ends])

    t_slices_x, t_slices_tiled = slices_map

    # keep track of repetitions:
    # local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size)
    reps_indices = [x.gshape[x.split] * rep for rep in (range(reps[x.split]))]
    t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=local_x.device).reshape(
        len(reps_indices), 1
    )
    for i, t in enumerate(t_slices_x):
        t = t.repeat((reps[x.split], 1))
        t += t_reps_indices
        t_slices_x[i] = t

    # distribution logic on current rank:
    distr_map = []
    slices_map = []
    for i in range(2):
        if i == 0:
            # send logic for x slices on rank
            local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1)
            local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1)
            t_tiled_starts, t_tiled_ends = t_slices_tiled
        else:
            # recv logic for tiled slices on rank
            local_x_starts, local_x_ends = t_slices_x
            t_tiled_starts = t_slices_tiled[0][rank]
            t_tiled_ends = t_slices_tiled[1][rank]
        # overlap of every (repetition, process) source range with the target ranges
        t_max_starts = torch.max(local_x_starts, t_tiled_starts)
        t_min_ends = torch.min(local_x_ends, t_tiled_ends)
        coords = torch.where(t_min_ends - t_max_starts > 0)
        # remove repeat offset from slices if sending
        if i == 0:
            t_max_starts -= t_reps_indices
            t_min_ends -= t_reps_indices
        starts = t_max_starts[coords].unsqueeze_(0)
        ends = t_min_ends[coords].unsqueeze_(0)
        slices_map.append(torch.cat((starts, ends), dim=0))
        distr_map.append(coords)

    # bookkeeping in preparation for Alltoallv
    send_map, recv_map = distr_map
    send_rep, send_to_ranks = send_map
    recv_rep, recv_from_ranks = recv_map
    send_slices, recv_slices = slices_map

    # do not assume that `x` is balanced
    _, displs = x.counts_displs()
    offset_x = displs[rank]
    # impose load-balance on output
    offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split)
    t_tiled = tiled.larray

    # counts = ends - starts, computed by negating the starts row and summing both rows
    active_send_counts = send_slices.clone()
    active_send_counts[0] *= -1
    active_send_counts = active_send_counts.sum(0)
    active_recv_counts = recv_slices.clone()
    active_recv_counts[0] *= -1
    active_recv_counts = active_recv_counts.sum(0)
    # global -> local indices
    send_slices -= offset_x
    recv_slices -= offset_tiled
    recv_buf = t_tiled.clone()
    # we need as many Alltoallv calls as repeats along the split axis
    for rep in range(reps[x.split]):
        # send_data, send_counts, send_displs on rank
        all_send_counts = [0] * size
        all_send_displs = [0] * size
        send_this_rep = torch.where(send_rep == rep)[0].tolist()
        dest_this_rep = send_to_ranks[send_this_rep].tolist()
        for i, j in zip(send_this_rep, dest_this_rep):
            all_send_counts[j] = active_send_counts[i].item()
            all_send_displs[j] = send_slices[0][i].item()
        local_send_slice = [slice(None)] * x.ndim
        local_send_slice[x.split] = slice(
            all_send_displs[0], all_send_displs[0] + sum(all_send_counts)
        )
        # index with a tuple: indexing with a list of slices is deprecated sequence indexing
        send_buf = local_x[tuple(local_send_slice)].clone()

        # recv_data, recv_counts, recv_displs on rank
        all_recv_counts = [0] * size
        all_recv_displs = [0] * size
        recv_this_rep = torch.where(recv_rep == rep)[0].tolist()
        orig_this_rep = recv_from_ranks[recv_this_rep].tolist()
        for i, j in zip(recv_this_rep, orig_this_rep):
            all_recv_counts[j] = active_recv_counts[i].item()
            all_recv_displs[j] = recv_slices[0][i].item()
        local_recv_slice = [slice(None)] * x.ndim
        local_recv_slice[x.split] = slice(
            all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts)
        )
        local_recv_slice = tuple(local_recv_slice)

        x.comm.Alltoallv(
            (send_buf, all_send_counts, all_send_displs),
            (recv_buf, all_recv_counts, all_recv_displs),
        )
        t_tiled[local_recv_slice] = recv_buf[local_recv_slice]

    # finally tile along non-split axes if needed
    reps[x.split] = 1
    tiled = DNDarray(
        t_tiled.repeat(reps),
        out_gshape,
        dtype=x.dtype,
        split=x.split,
        device=x.device,
        comm=x.comm,
        balanced=True,
    )
    if trans_axes != list(range(x.ndim)):
        # transpose back to original shape (only the result: the transposed `x` is a local
        # name that is not used any further)
        tiled = linalg.transpose(tiled, trans_axes)

    return tiled
def topk(
    a: DNDarray,
    k: int,
    dim: int = -1,
    largest: bool = True,
    sorted: bool = True,
    out: Optional[Tuple[DNDarray, DNDarray]] = None,
) -> Tuple[DNDarray, DNDarray]:
    """
    Returns the :math:`k` highest entries in the array.
    (Not Stable for split arrays)

    Parameters
    ----------
    a: DNDarray
        Input data
    k: int
        Desired number of output items
    dim: int, optional
        Dimension along which to sort, per default the last dimension
    largest: bool, optional
        If `True`, return the :math:`k` largest items, otherwise return the :math:`k` smallest items
    sorted: bool, optional
        Whether to sort the output (descending if `largest` is `True`, else ascending)
    out: Tuple[DNDarray, ...], optional
        output buffer

    Returns
    -------
    Tuple[DNDarray, DNDarray]
        The selected values and their indices along `dim`.

    Raises
    ------
    RuntimeError
        If `out[0]` does not have the dtype of `a` or `out[1]` is not of dtype `ht.int64`.
    ValueError
        If the shapes of the `out` buffers do not match the shapes of the results.

    Examples
    --------
    >>> a = ht.array([1, 2, 3])
    >>> ht.topk(a, 2)
    (DNDarray([3, 2], dtype=ht.int64, device=cpu:0, split=None), DNDarray([2, 1], dtype=ht.int64, device=cpu:0, split=None))
    >>> a = ht.array([[1, 2, 3], [1, 2, 3]])
    >>> ht.topk(a, 2, dim=1)
    (DNDarray([[3, 2],
               [3, 2]], dtype=ht.int64, device=cpu:0, split=None),
     DNDarray([[2, 1],
               [2, 1]], dtype=ht.int64, device=cpu:0, split=None))
    >>> a = ht.array([[1, 2, 3], [1, 2, 3]], split=1)
    >>> ht.topk(a, 2, dim=1)
    (DNDarray([[3, 2],
               [3, 2]], dtype=ht.int64, device=cpu:0, split=1),
     DNDarray([[2, 1],
               [2, 1]], dtype=ht.int64, device=cpu:0, split=1))
    """
    if out is not None:
        if out[0].dtype != a.dtype:
            raise RuntimeError(
                "dtypes of 'out[0]' and 'a' do not match, found {} != {}".format(
                    out[0].dtype, a.dtype
                )
            )
        if out[1].dtype != types.int64:
            raise RuntimeError(f"dtype of 'out[1]' is not ht.int64, found {out[1].dtype}")

    dim = stride_tricks.sanitize_axis(a.gshape, dim)

    # padding value that can never win the selection
    neutral_value = sanitation.sanitize_infinity(a)
    if largest:
        neutral_value = -neutral_value

    def local_topk(*args, **kwargs):
        # Local part of the reduction: packs metadata, values and indices into ONE flat buffer
        # laid out as [k, dim, largest, sorted, ndim, *shape, values..., indices...],
        # which is the layout `mpi_topk` parses.
        shape = a.lshape

        if shape[dim] < k:
            # fewer than k local entries along `dim`: take all of them. `dim` has to be passed
            # explicitly, torch.topk would otherwise select along the last dimension.
            result, indices = torch.topk(
                args[0], shape[dim], dim=dim, largest=largest, sorted=sorted
            )
            if dim == a.split:
                # Pad the result with neutral values to fill the buffer.
                # torch.nn.functional.pad lists (before, after) pairs starting from the LAST
                # dimension, hence the mirrored position of `dim` in the pad specification.
                pad_dim = dim % result.ndim
                pad_spec = [0] * (2 * result.ndim)
                pad_spec[2 * (result.ndim - 1 - pad_dim) + 1] = k - result.shape[pad_dim]
                result = torch.nn.functional.pad(
                    result, pad_spec, mode="constant", value=neutral_value
                )
                # Different value for indices padding to prevent type casting issues
                indices = torch.nn.functional.pad(indices, pad_spec, mode="constant", value=0)
        else:
            result, indices = torch.topk(args[0], k=k, dim=dim, largest=largest, sorted=sorted)

        # add offset of data chunks if reduction is computed across split axis
        if dim == a.split:
            offset, _, _ = a.comm.chunk(shape, a.split)
            indices = indices.clone()
            indices += torch.tensor(
                offset * a.comm.rank, dtype=indices.dtype, device=indices.device
            )

        local_shape = list(result.shape)
        local_shape_len = len(shape)

        metadata = torch.tensor(
            [k, dim, largest, sorted, local_shape_len, *local_shape], device=indices.device
        )
        if result.is_mps:
            # MPS does not support double precision
            send_buffer = torch.cat(
                (metadata.float(), result.float().flatten(), indices.flatten().float())
            )
        else:
            send_buffer = torch.cat(
                (metadata.double(), result.double().flatten(), indices.flatten().double())
            )

        return send_buffer

    gres = _operations.__reduce_op(
        a,
        local_topk,
        MPI_TOPK,
        axis=dim,
        neutral=neutral_value,
        dim=dim,
        sorted=sorted,
        largest=largest,
    )

    # Split data again to return a tuple
    local_result = gres.larray
    shape_len = int(local_result[4])

    gres, gindices = local_result[5 + shape_len :].chunk(2)
    gres = gres.reshape(*local_result[5 : 5 + shape_len].int())
    gindices = gindices.reshape(*local_result[5 : 5 + shape_len].int())

    # Create output with correct split
    if dim == a.split:
        is_split = None
        split = a.split
    else:
        is_split = a.split
        split = None

    # both outputs are created on the communicator of the input
    final_array = factories.array(
        gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split
    )
    final_indices = factories.array(
        gindices, dtype=types.int64, device=a.device, comm=a.comm, split=split, is_split=is_split
    )

    if out is not None:
        if out[0].shape != final_array.shape or out[1].shape != final_indices.shape:
            raise ValueError(
                "Expecting output buffer tuple of shape ({}, {}), got ({}, {})".format(
                    gres.shape, gindices.shape, out[0].shape, out[1].shape
                )
            )
        try:
            out[0].larray.untyped_storage().copy_(final_array.larray.untyped_storage())
            out[1].larray.untyped_storage().copy_(final_indices.larray.untyped_storage())
        except AttributeError:
            # tensors without `untyped_storage()`: fall back to `storage()`
            out[0].larray.storage().copy_(final_array.larray.storage())
            out[1].larray.storage().copy_(final_indices.larray.storage())

        out[0]._DNDarray__dtype = a.dtype
        out[1]._DNDarray__dtype = types.int64

    return final_array, final_indices
def mpi_topk(a, b, mpi_type):
    """
    MPI function for distributed :func:`topk`

    Both buffers are read as flat float64 arrays laid out as
    ``[k, dim, largest, sorted, ndim, *shape, values..., indices...]``. The two operands are
    concatenated along ``dim``, reduced to their top ``k`` entries, and the result is written
    back into ``b`` in place, prefixed with the header taken from ``a``.

    Parameters
    ----------
    a : buffer
        First operand.
    b : buffer
        Second operand; overwritten with the merged result.
    mpi_type
        Not used by this function.
    """
    lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
    rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))

    # reduction settings are read from the header of the first operand
    k, dim = (int(lhs[pos].item()) for pos in (0, 1))
    largest, keep_sorted = (bool(lhs[pos].item()) for pos in (2, 3))

    def _unpack(buf):
        # header is 5 fixed entries followed by the shape; the remainder holds values, then
        # indices, in two equally sized halves
        header_len = 5 + int(buf[4])
        dims = buf[5:header_len].int().tolist()
        vals, idx = buf[header_len:].chunk(2)
        return header_len, vals.reshape(dims), idx.reshape(dims)

    lhs_header_len, lhs_vals, lhs_idx = _unpack(lhs)
    _, rhs_vals, rhs_idx = _unpack(rhs)

    # merge both candidate sets and select the overall top k along `dim`
    merged_vals = torch.cat((lhs_vals, rhs_vals), dim=dim)
    merged_idx = torch.cat((lhs_idx, rhs_idx), dim=dim)
    top_vals, picks = torch.topk(merged_vals, k, dim=dim, largest=largest, sorted=keep_sorted)
    top_idx = torch.gather(merged_idx, dim, picks)

    header = lhs[0:lhs_header_len]
    rhs.copy_(torch.cat((header, top_vals.double().flatten(), top_idx.double().flatten())))


# custom MPI reduction operation wrapping `mpi_topk`, declared as commutative
MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)