Source code for heat.core.dndarray

"""Provides HeAT's core data structure, the DNDarray, a distributed n-dimensional array"""

from __future__ import annotations

import math
import numpy as np
import torch
import warnings

from inspect import stack
from mpi4py import MPI
from pathlib import Path
from typing import List, Union, Tuple, TypeVar, Optional

# Always show ResourceWarning; DNDarray.astype() below emits it when it has to
# downgrade a 64-bit dtype on MPS devices.
warnings.simplefilter("always", ResourceWarning)

# NOTE: heat module imports need to be placed at the very end of the file to avoid cyclic dependencies
__all__ = ["DNDarray"]

# Type variable used in annotations for the communication object of a DNDarray.
Communication = TypeVar("Communication")


class LocalIndex:
    """
    Small indexing proxy used for process-local access (see :func:`lloc`).

    Every subscript operation is forwarded unchanged to the wrapped object,
    which is stored in the public attribute ``obj``.
    """

    def __init__(self, obj):
        # the object that actually answers all subscript operations
        self.obj = obj

    def __getitem__(self, key):
        """Return ``obj[key]`` of the wrapped object."""
        wrapped = self.obj
        return wrapped[key]

    def __setitem__(self, key, value):
        """Assign ``value`` to ``obj[key]`` of the wrapped object."""
        wrapped = self.obj
        wrapped[key] = value


class DNDarray:
    """
    Distributed N-Dimensional array. The core element of HeAT. It is composed of
    PyTorch tensors local to each process.

    Parameters
    ----------
    array : torch.Tensor
        Local array elements
    gshape : Tuple[int,...]
        The global shape of the array
    dtype : datatype
        The datatype of the array
    split : int or None
        The axis on which the array is divided between processes
    device : Device
        The device on which the local arrays are using (cpu or gpu)
    comm : Communication
        The communications object for sending and receiving data
    balanced: bool or None
        Describes whether the data are evenly distributed across processes.
        If this information is not available (``self.balanced is None``), it
        can be gathered via the :func:`is_balanced()` method (requires communication).
    """

    def __init__(
        self,
        array: torch.Tensor,
        gshape: Tuple[int, ...],
        dtype: datatype,
        split: Union[int, None],
        device: Device,
        comm: Communication,
        balanced: bool,
    ):
        self.__array = array
        self.__gshape = gshape
        self.__dtype = dtype
        self.__split = split
        self.__device = device
        self.__comm = comm
        self.__balanced = balanced
        # halo bookkeeping, filled by get_halo()
        self.__ishalo = False
        self.__halo_next = None
        self.__halo_prev = None
        # lazily created caches
        self.__partitions_dict__ = None
        self.__lshape_map = None

        # check for inconsistencies between torch and heat devices
        assert str(array.device) == device.torch_device

    @property
    def balanced(self) -> bool:
        """
        Boolean value indicating if the DNDarray is balanced between the MPI processes
        """
        return self.__balanced

    @property
    def comm(self) -> Communication:
        """
        The :class:`~heat.core.communication.Communication` of the ``DNDarray``
        """
        return self.__comm

    @property
    def device(self) -> Device:
        """
        The :class:`~heat.core.devices.Device` of the ``DNDarray``
        """
        return self.__device

    @property
    def dtype(self) -> datatype:
        """
        The :class:`~heat.core.types.datatype` of the ``DNDarray``
        """
        return self.__dtype

    @property
    def gshape(self) -> Tuple:
        """
        Returns the global shape of the ``DNDarray`` across all processes
        """
        return self.__gshape

    @property
    def halo_next(self) -> torch.Tensor:
        """
        Returns the halo of the next process
        """
        return self.__halo_next

    @property
    def halo_prev(self) -> torch.Tensor:
        """
        Returns the halo of the previous process
        """
        return self.__halo_prev

    @property
    def larray(self) -> torch.Tensor:
        """
        Returns the underlying process-local ``torch.Tensor`` of the ``DNDarray``
        """
        return self.__array

    @larray.setter
    def larray(self, array: torch.Tensor):
        """
        Setter for ``self.larray``, the underlying local ``torch.Tensor`` of the ``DNDarray``.

        Parameters
        ----------
        array : torch.Tensor
            The new underlying local ``torch.tensor`` of the ``DNDarray``

        Warning
        -----------
        Please use this function with care, as it might corrupt/invalidate the metadata in the ``DNDarray`` instance.
        """
        # sanitize tensor input
        sanitation.sanitize_in_tensor(array)
        # verify consistency of tensor shape with global DNDarray
        sanitation.sanitize_lshape(self, array)
        # set balanced status: a changed extent along the split axis makes it unknown
        split = self.split
        if split is not None and array.shape[split] != self.lshape[split]:
            self.__balanced = None
        self.__array = array

    @property
    def nbytes(self) -> int:
        """
        Returns the number of bytes consumed by the global tensor. Equivalent to property gnbytes.

        Note
        ------------
        Does not include memory consumed by non-element attributes of the ``DNDarray`` object.
        """
        return self.__array.element_size() * self.size

    @property
    def ndim(self) -> int:
        """
        Number of dimensions of the ``DNDarray``
        """
        return len(self.__gshape)

    @property
    def __partitioned__(self) -> dict:
        """
        Return a dictionary containing information useful for working with the partitioned
        data. These items include the shape of the data on each process, the starting index of the data
        that a process has, the datatype of the data, the local devices, as well as the global
        partitioning scheme.

        An example of the output and shape is shown in :func:`ht.core.DNDarray.create_partition_interface <ht.core.DNDarray.create_partition_interface>`.

        Returns
        -------
        dictionary with the partition interface
        """
        if self.__partitions_dict__ is None:
            self.__partitions_dict__ = self.create_partition_interface()
        return self.__partitions_dict__

    @property
    def size(self) -> int:
        """
        Number of total elements of the ``DNDarray``
        """
        # Exact integer product of the global shape. Multiplying in floating point
        # (float32 is the only option on MPS) rounds large element counts.
        return int(math.prod(self.gshape))

    @property
    def gnbytes(self) -> int:
        """
        Returns the number of bytes consumed by the global ``DNDarray``

        Note
        -----------
        Does not include memory consumed by non-element attributes of the ``DNDarray`` object.
        """
        return self.nbytes

    @property
    def gnumel(self) -> int:
        """
        Returns the number of total elements of the ``DNDarray``
        """
        return self.size

    @property
    def imag(self) -> DNDarray:
        """
        Return the imaginary part of the ``DNDarray``.
        """
        return complex_math.imag(self)

    @property
    def lnbytes(self) -> int:
        """
        Returns the number of bytes consumed by the local ``torch.Tensor``

        Note
        -------------------
        Does not include memory consumed by non-element attributes of the ``DNDarray`` object.
        """
        return self.__array.element_size() * self.__array.nelement()

    @property
    def lnumel(self) -> int:
        """
        Number of elements of the ``DNDarray`` on each process
        """
        # numel() yields a plain Python int, also for 0-dimensional tensors
        return self.__array.numel()

    @property
    def lloc(self) -> Union[DNDarray, None]:
        """
        Local item setter and getter. i.e. this function operates on a local
        level and only on the PyTorch tensors composing the :class:`DNDarray`.
        This function uses the LocalIndex class. As getter, it returns the
        selection made on the process-local ``torch.Tensor``.

        Parameters
        ----------
        key : int or slice or Tuple[int,...]
            Indices of the desired data.
        value : scalar, optional
            All types compatible with pytorch tensors, if none given then this is a getter function

        Examples
        --------
        >>> a = ht.zeros((4, 5), split=0)
        DNDarray([[0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]], dtype=ht.float32, device=cpu:0, split=0)
        >>> a.lloc[1, 0:4]
        (1/2) tensor([0., 0., 0., 0.])
        (2/2) tensor([0., 0., 0., 0.])
        >>> a.lloc[1, 0:4] = torch.arange(1, 5)
        >>> a
        DNDarray([[0., 0., 0., 0., 0.],
                  [1., 2., 3., 4., 0.],
                  [0., 0., 0., 0., 0.],
                  [1., 2., 3., 4., 0.]], dtype=ht.float32, device=cpu:0, split=0)
        """
        return LocalIndex(self.__array)

    @property
    def lshape(self) -> Tuple[int]:
        """
        Returns the shape of the ``DNDarray`` on each node
        """
        return tuple(self.__array.shape)

    @property
    def lshape_map(self) -> torch.Tensor:
        """
        Returns the lshape map. If it hasn't been previously created then it will be created here.
        """
        return self.create_lshape_map()

    @property
    def real(self) -> DNDarray:
        """
        Return the real part of the ``DNDarray``.
        """
        return complex_math.real(self)

    @property
    def shape(self) -> Tuple[int]:
        """
        Returns the shape of the ``DNDarray`` as a whole
        """
        return self.__gshape

    @property
    def split(self) -> int:
        """
        Returns the axis on which the ``DNDarray`` is split
        """
        return self.__split

    @property
    def stride(self) -> Tuple[int]:
        """
        Returns the steps in each dimension when traversing a ``DNDarray``. torch-like usage: ``self.stride()``
        """
        # the bound method of the local tensor is returned (not called) so that
        # the torch-like call syntax ``x.stride()`` works on this property
        return self.__array.stride

    @property
    def strides(self) -> Tuple[int]:
        """
        Returns bytes to step in each dimension when traversing a ``DNDarray``. numpy-like usage: ``self.strides``
        """
        steps = list(self.larray.stride())
        try:
            itemsize = self.larray.untyped_storage().element_size()
        except AttributeError:
            # fall back to the storage() accessor when untyped_storage() is unavailable
            itemsize = self.larray.storage().element_size()
        strides = tuple(step * itemsize for step in steps)
        return strides

    @property
    def T(self):
        """
        Reverse the dimensions of a DNDarray.
        """
        # specialty docs for this version of transpose. The transpose function is in heat/core/linalg/basics
        return linalg.transpose(self, axes=None)

    @property
    def array_with_halos(self) -> torch.Tensor:
        """
        Return the local array concatenated with the halos stored in
        ``self.halo_prev``/``self.halo_next`` (as far as they are available).
        Halos are fetched beforehand with :func:`get_halo`.
        """
        return self.__cat_halo()

    def __prephalo(self, start, end) -> torch.Tensor:
        """
        Extracts the halo indexed by start, end from ``self.array`` in the direction of ``self.split``

        Parameters
        ----------
        start : int
            Start index of the halo extracted from ``self.array``
        end : int
            End index of the halo extracted from ``self.array``
        """
        ix = [slice(None, None, None)] * len(self.shape)
        try:
            ix[self.split] = slice(start, end)
        except IndexError:
            # best effort: report and fall through with an unrestricted index
            print("Indices out of bound")

        # clone so that the halo does not alias the local tensor
        return self.__array[tuple(ix)].clone()
def get_halo(self, halo_size: int, prev: bool = True, next: bool = True):
    """
    Fetch halos of size ``halo_size`` from neighboring ranks and save them in
    ``self.halo_next/self.halo_prev``.

    Nothing happens if the array is not distributed or ``halo_size`` is 0.
    Ranks that hold no data along the split axis return without taking part
    in the exchange.

    Parameters
    ----------
    halo_size : int
        Size of the halo.
    prev : bool, optional
        If True, fetch the halo from the previous rank. Default: True.
    next : bool, optional
        If True, fetch the halo from the next rank. Default: True.

    Raises
    ------
    TypeError
        If ``halo_size`` is not a Python integer.
    ValueError
        If ``halo_size`` is negative or larger than the local chunk of at
        least one populated process.
    """
    if not isinstance(halo_size, int):
        raise TypeError(
            f"halo_size needs to be of Python type integer, {type(halo_size)} given"
        )
    if halo_size < 0:
        raise ValueError(
            f"halo_size needs to be a non-negative Python integer, {halo_size} given"
        )

    if self.is_distributed() and halo_size > 0:
        # gather lshapes
        lshape_map = self.lshape_map
        rank = self.comm.rank

        # flatten() always yields a 1-D tensor and therefore a list; squeeze()
        # would collapse a single populated rank into a bare int
        populated_ranks = torch.nonzero(lshape_map[:, self.split]).flatten().tolist()
        if rank in populated_ranks:
            first_rank = populated_ranks[0]
            last_rank = populated_ranks[-1]
            if rank != last_rank:
                next_rank = populated_ranks[populated_ranks.index(rank) + 1]
            if rank != first_rank:
                prev_rank = populated_ranks[populated_ranks.index(rank) - 1]
        else:
            # if process has no data we ignore it
            return

        if (halo_size > lshape_map[:, self.split][populated_ranks]).any():
            # halo_size is larger than the local size on at least one process
            raise ValueError(
                f"halo_size {halo_size} needs to be smaller than chunk-size {self.lshape[self.split]} )"
            )

        # local boundary slabs that are sent to the neighbors
        a_prev = self.__prephalo(0, halo_size)
        a_next = self.__prephalo(-halo_size, None)
        res_prev = None
        res_next = None

        req_list = []

        # exchange data with next populated process
        if prev:
            if rank != last_rank:
                req_list.append(self.comm.Isend(a_next, next_rank))
            if rank != first_rank:
                res_prev = torch.empty(
                    a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
                )
                req_list.append(self.comm.Irecv(res_prev, source=prev_rank))

        if next:
            if rank != first_rank:
                req_list.append(self.comm.Isend(a_prev, prev_rank))
            if rank != last_rank:
                res_next = torch.empty(
                    a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
                )
                req_list.append(self.comm.Irecv(res_next, source=next_rank))

        # complete all non-blocking transfers before publishing the halos
        for req in req_list:
            req.Wait()

        self.__halo_next = res_next
        self.__halo_prev = res_prev
        self.__ishalo = True
def __cat_halo(self) -> torch.Tensor:
    """
    Return the local array, concatenated along the split axis with whichever
    halos are currently stored. Non-distributed arrays are returned as they are.
    """
    if not self.is_distributed():
        return self.__array

    # order matters: previous halo, local data, next halo
    pieces = []
    for candidate in (self.__halo_prev, self.__array, self.__halo_next):
        if candidate is not None:
            pieces.append(candidate)
    return torch.cat(pieces, dim=self.split)
def __array__(self, dtype=None, copy=None) -> np.ndarray:
    """
    Returns the process-local slice of the :class:`DNDarray` as a numpy ndarray.
    If the local tensor resides on CPU and no conversion is requested, the
    result shares memory with it; otherwise it is a CPU copy.

    Parameters
    ----------
    dtype : numpy dtype, optional
        If given, the result is converted to this dtype.
    copy : bool or None, optional
        If True, the result never shares memory with the local tensor.
        ``None`` and ``False`` are treated alike: memory is shared where possible.
    """
    host = self.larray.cpu().__array__()
    # only convert when asked to; astype(copy=False) is a no-op for an equal dtype
    out = host if dtype is None else host.astype(dtype, copy=False)
    if copy and out is host:
        out = out.copy()
    return out
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    """
    Override NumPy's universal functions: a plain ufunc call is forwarded to
    the heat function of the same name. ``NotImplemented`` is returned for
    every other ufunc method and when heat has no function of that name.
    """
    import heat

    # TODO support ufunc method variants
    if method != "__call__":
        return NotImplemented

    try:
        func = getattr(heat, ufunc.__name__)
    except AttributeError:
        return NotImplemented
    else:
        return func(*inputs, **kwargs)
def __array_function__(self, func, types, args, kwargs):
    """
    Augments NumPy's functions: the call is forwarded to the heat function of
    the same name, or ``NotImplemented`` is returned if heat has none.
    """
    import heat

    try:
        ht_func = getattr(heat, func.__name__)
    except AttributeError:
        return NotImplemented
    else:
        return ht_func(*args, **kwargs)
def astype(self, dtype, copy=True) -> DNDarray:
    """
    Returns a casted version of this array.
    With ``copy=True`` the result is a new array of the same shape with the
    given type. With ``copy=False`` this array itself is converted and returned.

    On MPS devices ``float64`` and ``complex128`` are replaced by ``float32``
    and ``complex64`` respectively, and a ``ResourceWarning`` is issued.

    Parameters
    ----------
    dtype : datatype
        Heat type to which the array is cast
    copy : bool, optional
        By default the operation returns a copy of this array. If copy is set to ``False`` the cast is performed
        in-place and this array is returned
    """
    dtype = canonical_heat_type(dtype)
    if self.__array.is_mps:
        # (unsupported type, replacement, warning text)
        mps_fallbacks = (
            (
                types.float64,
                types.float32,
                "MPS does not support float64. Casting to float32 instead.",
            ),
            (
                types.complex128,
                types.complex64,
                "MPS does not support complex128. Casting to complex64 instead.",
            ),
        )
        for unsupported, replacement, message in mps_fallbacks:
            if dtype == unsupported:
                # print warning
                warnings.warn(message, ResourceWarning)
                dtype = replacement
                break

    casted_array = self.__array.type(dtype.torch_type())
    if not copy:
        self.__array = casted_array
        self.__dtype = dtype
        return self

    return DNDarray(
        casted_array, self.shape, dtype, self.split, self.device, self.comm, self.balanced
    )
def balance_(self) -> DNDarray:
    """
    Function for balancing a :class:`DNDarray` between all nodes. To determine if this is needed use the
    :func:`is_balanced()` function. If the ``DNDarray`` is already balanced this function will do nothing.
    This function modifies the ``DNDarray`` itself and will not return anything.

    Examples
    --------
    >>> a = ht.zeros((10, 2), split=0)
    >>> a[:, 0] = ht.arange(10)
    >>> b = a[3:]
    [0/2] tensor([[3., 0.],
    [1/2] tensor([[4., 0.],
                  [5., 0.],
                  [6., 0.]])
    [2/2] tensor([[7., 0.],
                  [8., 0.],
                  [9., 0.]])
    >>> b.balance_()
    >>> print(b.gshape, b.lshape)
    [0/2] (7, 2) (1, 2)
    [1/2] (7, 2) (3, 2)
    [2/2] (7, 2) (3, 2)
    >>> b
    [0/2] tensor([[3., 0.],
                 [4., 0.],
                 [5., 0.]])
    [1/2] tensor([[6., 0.],
                  [7., 0.]])
    [2/2] tensor([[8., 0.],
                  [9., 0.]])
    >>> print(b.gshape, b.lshape)
    [0/2] (7, 2) (3, 2)
    [1/2] (7, 2) (2, 2)
    [2/2] (7, 2) (2, 2)
    """
    # the balanced status is always re-assessed, never taken from the stored flag
    if not self.is_balanced(force_check=True):
        self.redistribute_()
def __bool__(self) -> bool:
    """
    Boolean scalar casting, delegated to the generic scalar cast helper.
    """
    converter = bool
    return self.__cast(converter)
def __cast(self, cast_function) -> Union[float, int]:
    """
    Implements a generic cast function for ``DNDarray`` objects.

    For a split array the rank holding the single element is determined with
    an allreduce, and the converted value is then broadcast from that rank.

    Parameters
    ----------
    cast_function : function
        The actual cast function, e.g. ``float`` or ``int``

    Raises
    ------
    TypeError
        If the ``DNDarray`` object cannot be converted into a scalar.
    """
    if np.prod(self.shape) != 1:
        raise TypeError("only size-1 arrays can be converted to Python scalars")

    if self.split is None:
        return cast_function(self.__array)

    # ranks without data contribute 0, so the sum is the rank of the owner
    holds_nothing = np.prod(self.__array.shape) == 0
    root = self.comm.allreduce(0 if holds_nothing else self.comm.rank, op=MPI.SUM)
    payload = None if holds_nothing else cast_function(self.__array)
    return self.comm.bcast(payload, root=root)
def collect_(self, target_rank: Optional[int] = 0) -> None:
    """
    A method collecting a distributed DNDarray to one MPI rank, chosen by the `target_rank` variable.
    It is a specific case of the ``redistribute_`` method. Arrays that are not
    distributed are left untouched.

    Parameters
    ----------
    target_rank : int, optional
        The rank to which the DNDarray will be collected. Default: 0.

    Raises
    ------
    TypeError
        If the target rank is not an integer.
    ValueError
        If the target rank is out of bounds (only ``target_rank >= comm.size`` is checked).

    Examples
    --------
    >>> st = ht.ones((50, 81, 67), split=2)
    >>> print(st.lshape)
    [0/2] (50, 81, 23)
    [1/2] (50, 81, 22)
    [2/2] (50, 81, 22)
    >>> st.collect_()
    >>> print(st.lshape)
    [0/2] (50, 81, 67)
    [1/2] (50, 81, 0)
    [2/2] (50, 81, 0)
    >>> st.collect_(1)
    >>> print(st.lshape)
    [0/2] (50, 81, 0)
    [1/2] (50, 81, 67)
    [2/2] (50, 81, 0)
    """
    if not isinstance(target_rank, int):
        raise TypeError(f"target rank must be of type int , but was {type(target_rank)}")
    if target_rank >= self.comm.size:
        raise ValueError("target rank is out of bounds")
    if not self.is_distributed():
        return

    # target layout: the whole split axis lives on target_rank, nothing elsewhere
    split_axis = self.split
    target_map = self.lshape_map.clone()
    target_map[:, split_axis] = 0
    target_map[target_rank, split_axis] = self.gshape[split_axis]
    self.redistribute_(target_map=target_map)
def __complex__(self) -> DNDarray:
    """
    Complex scalar casting, delegated to the generic scalar cast helper.
    """
    converter = complex
    return self.__cast(converter)
def counts_displs(self) -> Tuple[Tuple[int], Tuple[int]]:
    """
    Returns actual counts (number of items per process) and displacements (offsets) of the DNDarray.
    Does not assume load balance.

    Raises
    ------
    ValueError
        If the DNDarray has no split axis.
    """
    if self.split is None:
        raise ValueError("Non-distributed DNDarray. Cannot calculate counts and displacements.")

    counts = self.lshape_map[:, self.split]
    # every displacement is the running sum of the counts before it
    offsets = torch.cumsum(counts, dim=0)[:-1].tolist()
    return tuple(counts.tolist()), (0, *offsets)
def cpu(self) -> DNDarray:
    """
    Moves the local tensor of this object to main memory and returns this
    object. The operation works in place: no new ``DNDarray`` is created.
    """
    host_tensor = self.__array.cpu()
    self.__array = host_tensor
    self.__device = devices.cpu
    return self
def create_lshape_map(self, force_check: bool = False) -> torch.Tensor:
    """
    Generate a 'map' of the lshapes of the data on all processes.
    Units are ``(process rank, lshape)``

    Parameters
    ----------
    force_check : bool, optional
        if False (default) and the lshape map has already been created, use the previous
        result. Otherwise, create the lshape_map

    Returns
    -------
    torch.Tensor
        Integer tensor of shape ``(comm.size, ndim)``. A copy of the cached map is returned.
    """
    if not force_check and self.__lshape_map is not None:
        return self.__lshape_map.clone()

    lshape_map = torch.zeros(
        (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device
    )
    # is_distributed is a method and has to be called; testing the bound method
    # itself is always truthy and would skip this communication-free branch
    if not self.is_distributed():
        # every process holds the complete array
        lshape_map[:] = torch.tensor(self.gshape, device=self.device.torch_device)
        self.__lshape_map = lshape_map
        return lshape_map.clone()

    if self.is_balanced(force_check=True):
        # balanced layout: the chunk sizes follow from the global shape alone
        for i in range(self.comm.size):
            _, lshape, _ = self.comm.chunk(self.gshape, self.split, rank=i)
            lshape_map[i, :] = torch.tensor(lshape, device=self.device.torch_device)
    else:
        # each rank fills in its own row, the sum over all ranks is the full map
        lshape_map[self.comm.rank, :] = torch.tensor(
            self.lshape, device=self.device.torch_device
        )
        self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

    self.__lshape_map = lshape_map
    return lshape_map.clone()
def create_partition_interface(self):
    """
    Create a partition interface in line with the DPPY proposal. This is subject to change.
    The intention of this to facilitate the usage of a general format for the referencing of
    distributed datasets.

    An example of the output and shape is shown below.

    __partitioned__ = {
        'shape': (27, 3, 2),
        'partition_tiling': (4, 1, 1),
        'partitions': {
            (0, 0, 0): {
                'start': (0, 0, 0),
                'shape': (7, 3, 2),
                'data': tensor([...], dtype=torch.int32),
                'location': [0],
                'dtype': torch.int32,
                'device': 'cpu'
            },
            (1, 0, 0): {
                'start': (7, 0, 0),
                'shape': (7, 3, 2),
                'data': None,
                'location': [1],
                'dtype': torch.int32,
                'device': 'cpu'
            },
            (2, 0, 0): {
                'start': (14, 0, 0),
                'shape': (7, 3, 2),
                'data': None,
                'location': [2],
                'dtype': torch.int32,
                'device': 'cpu'
            },
            (3, 0, 0): {
                'start': (21, 0, 0),
                'shape': (6, 3, 2),
                'data': None,
                'location': [3],
                'dtype': torch.int32,
                'device': 'cpu'
            }
        },
        'locals': [(rank, 0, 0)],
        'get': lambda x: x,
    }

    Returns
    -------
    dictionary containing the partition interface as shown above.
    """
    lshape_map = self.create_lshape_map()
    start_idx_map = torch.zeros_like(lshape_map)

    part_tiling = [1] * self.ndim
    lcls = [0] * self.ndim

    if self.split is not None:
        # The start offsets are indices: build them in the integer dtype of the
        # lshape map. A zero in the array's own dtype would make torch.cat
        # promote the offsets to that dtype and round large values.
        zero = torch.zeros(1, dtype=lshape_map.dtype, device=self.device.torch_device)
        starts = torch.cat(
            (zero, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0
        )
        lcls[self.split] = self.comm.rank
        part_tiling[self.split] = self.comm.size
    else:
        starts = torch.zeros(self.ndim, dtype=torch.int, device=self.device.torch_device)

    start_idx_map[:, self.split] = starts

    partitions = {}
    base_key = [0] * self.ndim
    for r in range(self.comm.size):
        if self.split is not None:
            base_key[self.split] = r
            # only the local partition carries data
            dat = None if r != self.comm.rank else self.larray
        else:
            dat = self.larray
        partitions[tuple(base_key)] = {
            "start": tuple(start_idx_map[r].tolist()),
            "shape": tuple(lshape_map[r].tolist()),
            "data": dat,
            "location": [r],
            "dtype": self.dtype.torch_type(),
            "device": self.device.torch_device,
        }

    partition_dict = {
        "shape": self.gshape,
        "partition_tiling": tuple(part_tiling),
        "partitions": partitions,
        "locals": [tuple(lcls)],
        "get": lambda x: x,
    }

    self.__partitions_dict__ = partition_dict

    return partition_dict
def __float__(self) -> DNDarray:
    """
    Float scalar casting, delegated to the generic scalar cast helper.

    See Also
    --------
    :func:`~heat.core.manipulations.flatten`
    """
    converter = float
    return self.__cast(converter)
def fill_diagonal(self, value: float) -> DNDarray:
    """
    Fill the main diagonal of a 2D :class:`DNDarray`.
    This function modifies the input tensor in-place, and returns the input array.

    Parameters
    ----------
    value : float
        The value to be placed in the ``DNDarrays`` main diagonal

    Raises
    ------
    ValueError
        If the array is not two-dimensional.
    """
    # Todo: make this 3D/nD
    if len(self.shape) != 2:
        raise ValueError("Only 2D tensors supported at the moment")

    # is_distributed is a method of the communicator and has to be called;
    # the bound method itself is always truthy
    if self.split is not None and self.comm.is_distributed():
        _, displ, _ = self.comm.counts_displs_shape(self.shape, self.split)
        k = min(self.shape[0], self.shape[1])
        # proc: last process whose chunk starts within the diagonal
        for p in range(self.comm.size):
            if displ[p] > k:
                break
            proc = p
        if self.comm.rank <= proc:
            # range of the diagonal that falls into the local chunk
            indices = (
                displ[self.comm.rank],
                displ[self.comm.rank + 1] if (self.comm.rank + 1) != self.comm.size else k,
            )
            if self.split == 0:
                self.larray[:, indices[0] : indices[1]] = self.larray[
                    :, indices[0] : indices[1]
                ].fill_diagonal_(value)
            elif self.split == 1:
                self.larray[indices[0] : indices[1], :] = self.larray[
                    indices[0] : indices[1], :
                ].fill_diagonal_(value)
    else:
        self.larray = self.larray.fill_diagonal_(value)

    return self
def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
    """
    Global getter function for DNDarrays.
    Returns a new DNDarray composed of the elements of the original tensor selected by the indices
    given. This does *NOT* redistribute or rebalance the resulting tensor. If the selection of values is
    unbalanced then the resultant tensor is also unbalanced!
    To redistributed the ``DNDarray`` use :func:`balance()` (issue #187)

    Parameters
    ----------
    key : int, slice, Tuple[int,...], List[int,...]
        Indices to get from the tensor.

    Raises
    ------
    ValueError
        If the key contains more than one ellipsis.

    Examples
    --------
    >>> a = ht.arange(10, split=0)
    (1/2) >>> tensor([0, 1, 2, 3, 4], dtype=torch.int32)
    (2/2) >>> tensor([5, 6, 7, 8, 9], dtype=torch.int32)
    >>> a[1:6]
    (1/2) >>> tensor([1, 2, 3, 4], dtype=torch.int32)
    (2/2) >>> tensor([5], dtype=torch.int32)
    >>> a = ht.zeros((4, 5), split=0)
    (1/2) >>> tensor([[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]])
    (2/2) >>> tensor([[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]])
    >>> a[1:4, 1]
    (1/2) >>> tensor([0.])
    (2/2) >>> tensor([0., 0.])
    """
    # Work on a copy of the key where the key type offers one (list, ndarray,
    # DNDarray, ...). The attribute is called "copy"; a lookup of "copy()"
    # can never succeed and would silently skip the copy.
    copy_key = getattr(key, "copy", None)
    if callable(copy_key):
        key = copy_key()
    l_dtype = self.dtype.torch_type()
    advanced_ind = False
    if isinstance(key, DNDarray) and key.ndim == self.ndim:
        # if the key is a DNDarray and it has as many dimensions as self, then each of the
        # entries in the 0th dim refer to a single element. To handle this, the key is split
        # into the torch tensors for each dimension. This signals that advanced indexing is
        # to be used.
        # NOTE: this gathers the entire key on every process!!
        # TODO: remove this resplit!!
        key = manipulations.resplit(key)
        if key.larray.dtype in [torch.bool, torch.uint8]:
            key = indexing.nonzero(key)

        if key.ndim > 1:
            key = list(key.larray.split(1, dim=1))
            # key is now a list of tensors with dimensions (key.ndim, 1)
            # squeeze singleton dimension:
            key = [key[i].squeeze_(1) for i in range(len(key))]
        else:
            key = [key]
        advanced_ind = True
    elif not isinstance(key, tuple):
        # this branch handles all other cases. DNDarrays which make it to here refer to
        # advanced indexing slices, as do the torch tensors. Both DNDarrays and
        # torch.Tensors are cast into lists here by PyTorch. lists mean advanced indexing
        # will be used
        h = [slice(None, None, None)] * max(self.ndim, 1)
        if isinstance(key, DNDarray):
            key = manipulations.resplit(key)
            if key.larray.dtype in [torch.bool, torch.uint8]:
                h[0] = torch.nonzero(key.larray).flatten()  # .tolist()
            else:
                h[0] = key.larray.tolist()
        elif isinstance(key, torch.Tensor):
            if key.dtype in [torch.bool, torch.uint8]:
                # (coquelin77) i am not certain why this works without being a list. but it works...for now
                h[0] = torch.nonzero(key).flatten()  # .tolist()
            else:
                h[0] = key.tolist()
        else:
            h[0] = key

        key = list(h)

    if isinstance(key, (list, tuple)):
        key = list(key)
        for i, k in enumerate(key):
            # this might be a good place to check if the dtype is there
            try:
                k = manipulations.resplit(k)
                key[i] = k.larray
            except AttributeError:
                pass

    # ellipsis
    key = list(key)
    key_classes = [type(n) for n in key]
    # if any(isinstance(n, ellipsis) for n in key):
    n_elips = key_classes.count(type(...))
    if n_elips > 1:
        raise ValueError("key can only contain 1 ellipsis")
    elif n_elips == 1:
        # get which item is the ellipsis
        ell_ind = key_classes.index(type(...))
        kst = key[:ell_ind]
        kend = key[ell_ind + 1 :]
        slices = [slice(None)] * (self.ndim - (len(kst) + len(kend)))
        key = kst + slices + kend
    else:
        # pad the key with full slices up to the number of dimensions
        key = key + [slice(None)] * (self.ndim - len(key))

    self_proxy = self.__torch_proxy__()
    for i in range(len(key)):
        if self.__key_adds_dimension(key, i, self_proxy):
            # add the new axis first, then index the expanded array
            key[i] = slice(None)
            return self.expand_dims(i)[tuple(key)]

    key = tuple(key)
    # assess final global shape
    gout_full = list(self_proxy[key].shape)

    # calculate new split axis
    new_split = self.split
    # when slicing, squeezed singleton dimensions may affect new split axis
    if self.split is not None and len(gout_full) < self.ndim:
        if advanced_ind:
            new_split = 0
        else:
            for i in range(len(key[: self.split + 1])):
                if self.__key_is_singular(key, i, self_proxy):
                    new_split = None if i == self.split else new_split - 1

    key = tuple(key)
    if not self.is_distributed():
        arr = self.__array[key].reshape(gout_full)
        return DNDarray(
            arr, tuple(gout_full), self.dtype, new_split, self.device, self.comm, self.balanced
        )

    # else: (DNDarray is distributed)
    arr = torch.tensor([], dtype=self.__array.dtype, device=self.__array.device)
    rank = self.comm.rank
    counts, chunk_starts = self.counts_displs()
    counts, chunk_starts = torch.tensor(counts), torch.tensor(chunk_starts)
    chunk_ends = chunk_starts + counts
    chunk_start = chunk_starts[rank]
    chunk_end = chunk_ends[rank]

    if len(key) == 0:  # handle empty list
        # this will return an array of shape (0, ...)
        arr = self.__array[key]

    # At the end of the following if/elif/elif block the output array will be set.
    # each block handles the case where the element of the key along the split axis
    # is a different type and converts the key from global indices to local indices.
    lout = gout_full.copy()

    if (
        isinstance(key[self.split], (list, torch.Tensor, DNDarray, np.ndarray))
        and len(key[self.split]) > 1
    ):
        # advanced indexing, elements in the split dimension are adjusted to the local indices
        lkey = list(key)
        if isinstance(key[self.split], DNDarray):
            lkey[self.split] = key[self.split].larray

        if not isinstance(lkey[self.split], torch.Tensor):
            inds = torch.tensor(
                lkey[self.split], dtype=torch.long, device=self.device.torch_device
            )
        elif lkey[self.split].dtype in [torch.bool, torch.uint8]:  # or torch.byte?
            # need to convert the bools to indices
            inds = torch.nonzero(lkey[self.split])
        else:
            inds = lkey[self.split]
        # todo: remove where in favor of nonzero? might be a speed upgrade. testing required
        loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end))
        # if there are no local indices on a process, then `arr` is empty
        # if local indices exist:
        if len(loc_inds[0]) != 0:
            # select same local indices for other (non-split) dimensions if necessary
            for i, k in enumerate(lkey):
                if isinstance(k, (list, torch.Tensor, DNDarray)) and i != self.split:
                    lkey[i] = k[loc_inds]
            # correct local indices for offset
            inds = inds[loc_inds] - chunk_start
            lkey[self.split] = inds
            lout[new_split] = len(inds)
            arr = self.__array[tuple(lkey)].reshape(tuple(lout))
        elif len(loc_inds[0]) == 0:
            if new_split is not None:
                lout[new_split] = len(loc_inds[0])
            else:
                lout = [0] * len(gout_full)
            arr = torch.tensor([], dtype=self.larray.dtype, device=self.larray.device).reshape(
                tuple(lout)
            )

    elif isinstance(key[self.split], slice):
        # standard slicing along the split axis,
        #   adjust the slice start, stop, and step, then run it on the processes which have the requested data
        key = list(key)
        key[self.split] = stride_tricks.sanitize_slice(key[self.split], self.gshape[self.split])
        key_start, key_stop, key_step = (
            key[self.split].start,
            key[self.split].stop,
            key[self.split].step,
        )
        og_key_start = key_start
        # first and last process touched by the slice
        st_pr = torch.where(key_start < chunk_ends)[0]
        st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
        sp_pr = torch.where(key_stop >= chunk_starts)[0]
        sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
        actives = list(range(st_pr, sp_pr + 1))
        if rank in actives:
            key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
            key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
            key_start, key_stop = self.__xitem_get_key_start_stop(
                rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start
            )
            key[self.split] = slice(key_start, key_stop, key_step)
            lout[new_split] = (
                math.ceil((key_stop - key_start) / key_step)
                if key_step is not None
                else key_stop - key_start
            )
            arr = self.__array[tuple(key)].reshape(lout)
        else:
            lout[new_split] = 0
            arr = torch.empty(lout, dtype=self.__array.dtype, device=self.__array.device)

    elif self.__key_is_singular(key, self.split, self_proxy):
        # getting one item along split axis:
        key = list(key)
        if isinstance(key[self.split], list):
            key[self.split] = key[self.split].pop()
        elif isinstance(key[self.split], (torch.Tensor, DNDarray, np.ndarray)):
            key[self.split] = key[self.split].item()
        # translate negative index
        if key[self.split] < 0:
            key[self.split] += self.gshape[self.split]

        active_rank = torch.where(key[self.split] >= chunk_starts)[0][-1].item()
        # slice `self` on `active_rank`, allocate `arr` on all other ranks in preparation for Bcast
        if rank == active_rank:
            key[self.split] -= chunk_start.item()
            arr = self.__array[tuple(key)].reshape(tuple(lout))
        else:
            arr = torch.empty(tuple(lout), dtype=self.larray.dtype, device=self.larray.device)
        # broadcast result
        # TODO: Replace with `self.comm.Bcast(arr, root=active_rank)` after fixing #784
        arr = self.comm.bcast(arr, root=active_rank)
        if arr.device != self.larray.device:
            # todo: remove when unnecessary (also after #784)
            arr = arr.to(device=self.larray.device)

    return DNDarray(
        arr.type(l_dtype),
        gout_full if isinstance(gout_full, tuple) else tuple(gout_full),
        self.dtype,
        new_split,
        self.device,
        self.comm,
        balanced=True if new_split is None else None,
    )
if torch.cuda.device_count() > 0:

    def gpu(self) -> DNDarray:
        """
        Moves the local tensor of this object to GPU memory and returns this
        object. The operation works in place: no new ``DNDarray`` is created.
        This method is only defined if torch reports at least one CUDA device.
        """
        device_tensor = self.__array.cuda(devices.gpu.torch_device)
        self.__array = device_tensor
        self.__device = devices.gpu
        return self
def __int__(self) -> DNDarray:
    """
    Integer scalar casting, delegated to the generic scalar cast helper.
    """
    converter = int
    return self.__cast(converter)
def is_balanced(self, force_check: bool = False) -> bool:
    """
    Determine if ``self`` is distributed evenly (or as evenly as possible) across all processes.
    This is equivalent to returning ``self.balanced``. If no information
    is available (``self.balanced = None``), the balanced status will be
    assessed via collective communication.

    Parameters
    ----------
    force_check : bool, optional
        If True, the balanced status of the ``DNDarray`` will be assessed via
        collective communication in any case.
    """
    known_status = self.balanced
    if known_status is not None and not force_check:
        return known_status

    # local shape this process would hold in a balanced layout
    expected_slices = self.comm.chunk(self.shape, self.split)[2]
    expected_lshape = tuple(sl.stop - sl.start for sl in expected_slices)
    locally_balanced = int(expected_lshape == self.lshape)
    # balanced only if every single process reports a match
    n_balanced = self.comm.allreduce(locally_balanced, MPI.SUM)
    return n_balanced == self.comm.size
def is_distributed(self) -> bool:
    """
    Determines whether the data of this ``DNDarray`` is distributed across multiple processes.
    An array without split axis is never distributed; otherwise the answer of
    the communicator is returned.
    """
    if self.split is None:
        return False
    return self.comm.is_distributed()
@staticmethod
def __key_is_singular(key: any, axis: int, self_proxy: torch.Tensor) -> bool:
    """
    Check whether ``key[axis]`` selects a single item along ``axis``.

    The proxy tensor is indexed with ``key[axis]`` at position ``axis`` and ``0`` in the
    remaining ``self_proxy.ndim - 1`` positions; the key is singular if the result is 0-dimensional.
    """
    filler = (0,) * (self_proxy.ndim - 1)
    leading, trailing = filler[:axis], filler[axis:]
    probe_index = leading + (key[axis],) + trailing
    return self_proxy[probe_index].ndim == 0


@staticmethod
def __key_adds_dimension(key: any, axis: int, self_proxy: torch.Tensor) -> bool:
    """
    Check whether ``key[axis]`` introduces an additional dimension.

    The proxy tensor is indexed with ``key[axis]`` at position ``axis`` and ``0`` in the
    remaining ``self_proxy.ndim - 1`` positions; a 2-dimensional result means that the key
    added a dimension.
    """
    filler = (0,) * (self_proxy.ndim - 1)
    leading, trailing = filler[:axis], filler[axis:]
    probe_index = leading + (key[axis],) + trailing
    return self_proxy[probe_index].ndim == 2
def item(self):
    """
    Returns the only element of a 1-element :class:`DNDarray`.

    Mirror of the pytorch command by the same name. If the ``DNDarray`` holds more than one
    element, a ``ValueError`` is raised. Note that the array is gathered in place
    (``resplit_(None)``) before the element is read from the local tensor.

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.zeros((1))
    >>> x.item()
    0.0
    """
    if self.size > 1:
        message = "only one-element DNDarrays can be converted to Python scalars"
        raise ValueError(message)
    # gather first, so that the element is available on every process
    self.resplit_(None)
    local_tensor = self.__array
    return local_tensor.item()
def __len__(self) -> int:
    """
    The length of the ``DNDarray``, i.e. the number of items in the first (global) dimension.
    """
    global_shape = self.shape
    return global_shape[0]
def numpy(self) -> np.ndarray:
    """
    Returns a copy of the :class:`DNDarray` as numpy ndarray.

    The array is copied and the copy is gathered (``resplit_(axis=None)``), so ``self`` keeps
    its split. The local tensor of the gathered copy is moved to the CPU before the
    conversion. For a distributed ``DNDarray`` this means each MPI process ends up holding a
    copy of the entire array in memory. Make sure process memory is sufficient!

    Returns
    -------
    np.ndarray
        The global array data.

    Examples
    --------
    >>> import heat as ht
    >>> T1 = ht.random.randn((10, 8))
    >>> T1.numpy()
    """
    gathered = self.copy().resplit_(axis=None)
    return gathered.larray.cpu().numpy()
def _repr_pretty_(self, p, cycle):
    """
    Pretty print for IPython.

    Parameters
    ----------
    p : object
        The pretty printer handed in by IPython; the string representation is written to it
        via ``p.text``.
    cycle : bool
        Flag passed by IPython's pretty-printing protocol. The same text is emitted
        regardless of its value.
    """
    # the former ``if cycle`` / ``else`` branches were identical, so a single call suffices
    p.text(printing.__str__(self))
def __repr__(self) -> str:
    """
    Returns a printable representation of the ``DNDarray``, targeting developers.

    The text is produced by ``printing.__repr__``.
    """
    representation = printing.__repr__(self)
    return representation
def ravel(self):
    """
    Flattens the ``DNDarray``.

    Thin wrapper that forwards ``self`` to :func:`~heat.core.manipulations.ravel` and returns
    its result.

    See Also
    --------
    :func:`~heat.core.manipulations.ravel`

    Examples
    --------
    >>> a = ht.ones((2, 3), split=0)
    >>> b = a.ravel()
    >>> a[0, 0] = 4
    >>> b
    DNDarray([4., 1., 1., 1., 1., 1.], dtype=ht.float32, device=cpu:0, split=0)
    """
    flattened = manipulations.ravel(self)
    return flattened
def redistribute_(
    self, lshape_map: Optional[torch.Tensor] = None, target_map: Optional[torch.Tensor] = None
):
    """
    Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map.
    This function does not modify the non-split dimensions of the ``DNDarray``.
    This is an abstraction and extension of the balance function.

    If the array is not distributed, the call returns immediately without any change.
    Afterwards the balanced flag is ``True`` if no ``target_map`` was given (the data was
    balanced) and ``None`` (unknown, to be determined via :func:`is_balanced`) if a custom
    ``target_map`` was applied.

    Parameters
    ----------
    lshape_map : torch.Tensor, optional
        The current lshape of processes.
        Units are ``[rank, lshape]``. If given, it is updated in place while data is moved.
        If omitted, it is created via ``create_lshape_map(force_check=True)``.
    target_map : torch.Tensor, optional
        The desired distribution across the processes.
        Units are ``[rank, target lshape]``.
        Note: the only important parts of the target map are the values along the split axis,
        values which are not along this axis are there to mimic the shape of the ``lshape_map``.
        If omitted, the array is balanced.

    Raises
    ------
    TypeError
        If ``lshape_map`` is given but is not a ``torch.Tensor``.
    ValueError
        If ``lshape_map`` or ``target_map`` do not have the shape ``(comm.size, ndim)``, or if
        the entries of ``target_map`` along the split axis do not sum up to the global size
        of the split dimension.

    Examples
    --------
    >>> st = ht.ones((50, 81, 67), split=2)
    >>> target_map = torch.zeros((st.comm.size, 3), dtype=torch.int64)
    >>> target_map[0, 2] = 67
    >>> print(target_map)
    [0/2] tensor([[ 0,  0, 67],
    [0/2]         [ 0,  0,  0],
    [0/2]         [ 0,  0,  0]])
    [1/2] tensor([[ 0,  0, 67],
    [1/2]         [ 0,  0,  0],
    [1/2]         [ 0,  0,  0]])
    [2/2] tensor([[ 0,  0, 67],
    [2/2]         [ 0,  0,  0],
    [2/2]         [ 0,  0,  0]])
    >>> print(st.lshape)
    [0/2] (50, 81, 23)
    [1/2] (50, 81, 22)
    [2/2] (50, 81, 22)
    >>> st.redistribute_(target_map=target_map)
    >>> print(st.lshape)
    [0/2] (50, 81, 67)
    [1/2] (50, 81, 0)
    [2/2] (50, 81, 0)
    """
    if not self.is_distributed():
        return
    snd_dtype = self.dtype.torch_type()
    # units -> {pr, 1st index, 2nd index}
    if lshape_map is None:
        # NOTE: giving an lshape map which is incorrect will result in an incorrect distribution
        lshape_map = self.create_lshape_map(force_check=True)
    else:
        if not isinstance(lshape_map, torch.Tensor):
            raise TypeError(f"lshape_map must be a torch.Tensor, currently {type(lshape_map)}")
        if lshape_map.shape != (self.comm.size, len(self.gshape)):
            raise ValueError(
                f"lshape_map must have the shape ({self.comm.size}, {len(self.gshape)}), currently {lshape_map.shape}"
            )

    # remember whether this call balances the array; the balanced flag is only written at the
    # very end so that the recursive second pass below (which always receives an explicit
    # target map) cannot overwrite it
    balancing = target_map is None
    if balancing:
        # if no target map is given then it will balance the tensor
        target_map = lshape_map.clone()
        target_map[..., self.split] = 0
        for pr in range(self.comm.size):
            target_map[pr, self.split] = self.comm.chunk(self.shape, self.split, rank=pr)[1][
                self.split
            ]
    else:
        sanitation.sanitize_in_tensor(target_map)
        # check the shape first: the sum below indexes the split column of the map
        if target_map.shape != (self.comm.size, len(self.gshape)):
            raise ValueError(
                f"target_map must have the shape {(self.comm.size, len(self.gshape))}, currently {target_map.shape}"
            )
        if target_map[..., self.split].sum() != self.shape[self.split]:
            raise ValueError(
                f"Sum along the split axis of the target map must be equal to the shape in that dimension, currently {target_map[..., self.split]}"
            )

    lshape_cumsum = torch.cumsum(lshape_map[..., self.split], dim=0)
    chunk_cumsum = torch.cat(
        (
            torch.tensor([0], device=self.device.torch_device),
            torch.cumsum(target_map[..., self.split], dim=0),
        ),
        dim=0,
    )
    # need the data start as well for process 0
    for rcv_pr in range(self.comm.size - 1):
        st = chunk_cumsum[rcv_pr].item()
        sp = chunk_cumsum[rcv_pr + 1].item()
        # start pr should be the next process with data
        if lshape_map[rcv_pr, self.split] >= target_map[rcv_pr, self.split]:
            # if there is more data on the process than the start process than start == stop
            st_pr = rcv_pr
            sp_pr = rcv_pr
        else:
            # if there is less data on the process than need to get the data from the next data
            # with data
            # need processes > rcv_pr with lshape > 0
            st_pr = (
                torch.nonzero(input=lshape_map[rcv_pr:, self.split] > 0, as_tuple=False)[
                    0
                ].item()
                + rcv_pr
            )
            hld = (
                torch.nonzero(input=sp <= lshape_cumsum[rcv_pr:], as_tuple=False).flatten()
                + rcv_pr
            )
            sp_pr = hld[0].item() if hld.numel() > 0 else self.comm.size

        # st_pr and sp_pr are the processes on which the data sits at the beginning
        # need to loop from st_pr to sp_pr + 1 and send the pr
        for snd_pr in range(st_pr, sp_pr + 1):
            if snd_pr == self.comm.size:
                break
            data_required = abs(sp - st - lshape_map[rcv_pr, self.split].item())
            send_amt = (
                data_required
                if data_required <= lshape_map[snd_pr, self.split]
                else lshape_map[snd_pr, self.split]
            )
            if (sp - st) <= lshape_map[rcv_pr, self.split].item() or snd_pr == rcv_pr:
                send_amt = 0
            # send amount is the data still needed by recv if that is available on the snd
            if send_amt != 0:
                self.__redistribute_shuffle(
                    snd_pr=snd_pr, send_amt=send_amt, rcv_pr=rcv_pr, snd_dtype=snd_dtype
                )
            lshape_cumsum[snd_pr] -= send_amt
            lshape_cumsum[rcv_pr] += send_amt
            lshape_map[rcv_pr, self.split] += send_amt
            lshape_map[snd_pr, self.split] -= send_amt
        if lshape_map[rcv_pr, self.split] > target_map[rcv_pr, self.split]:
            # if there is any data left on the process then send it to the next one
            send_amt = lshape_map[rcv_pr, self.split] - target_map[rcv_pr, self.split]
            self.__redistribute_shuffle(
                snd_pr=rcv_pr, send_amt=send_amt.item(), rcv_pr=rcv_pr + 1, snd_dtype=snd_dtype
            )
            lshape_cumsum[rcv_pr] -= send_amt
            lshape_cumsum[rcv_pr + 1] += send_amt
            lshape_map[rcv_pr, self.split] -= send_amt
            lshape_map[rcv_pr + 1, self.split] += send_amt

    if any(lshape_map[..., self.split] != target_map[..., self.split]):
        # sometimes need to call the redistribute once more,
        # (in the case that the second to last processes needs to get data from +1 and -1)
        self.redistribute_(lshape_map=lshape_map, target_map=target_map)

    # balancing yields a balanced array; for a custom target map there is no info on the
    # balanced status, which is signalled by None (not False)
    self.__balanced = True if balancing else None
    self.__lshape_map = target_map
def __redistribute_shuffle(
    self,
    snd_pr: Union[int, torch.Tensor],
    send_amt: Union[int, torch.Tensor],
    rcv_pr: Union[int, torch.Tensor],
    snd_dtype: torch.dtype,
):
    """
    Helper of the redistribution: moves ``send_amt`` entries along the split axis from process
    ``snd_pr`` to process ``rcv_pr``. Processes that are neither sender nor receiver do nothing.

    Parameters
    ----------
    snd_pr : int or torch.Tensor
        Sending process
    send_amt : int or torch.Tensor
        Amount of data to be sent by the sending process
    rcv_pr : int or torch.Tensor
        Receiving process
    snd_dtype : torch.dtype
        Torch type of the data in question
    """
    my_rank = self.comm.rank
    to_higher_rank = snd_pr < rcv_pr
    to_lower_rank = snd_pr > rcv_pr

    if my_rank == snd_pr:
        split_axis = self.split
        local_len = self.lshape[split_axis]
        outgoing = [slice(None)] * self.ndim
        remaining = [slice(None)] * self.ndim
        if to_higher_rank:
            # data passed to a higher rank is taken off the bottom of the local chunk
            outgoing[split_axis] = slice(local_len - send_amt, local_len)
            remaining[split_axis] = slice(0, local_len - send_amt)
        if to_lower_rank:
            # data passed to a lower rank is taken off the top of the local chunk
            outgoing[split_axis] = slice(0, send_amt)
            remaining[split_axis] = slice(send_amt, local_len)
        payload = self.__array[tuple(outgoing)].clone()
        self.comm.Send(payload, dest=rcv_pr, tag=685)
        self.__array = self.__array[tuple(remaining)]

    if my_rank == rcv_pr:
        recv_shape = list(self.gshape)
        recv_shape[self.split] = send_amt
        incoming = torch.zeros(recv_shape, dtype=snd_dtype, device=self.device.torch_device)
        self.comm.Recv(incoming, source=snd_pr, tag=685)
        if to_higher_rank:
            # data coming from a lower rank is placed in front of the local data
            self.__array = torch.cat((incoming, self.__array), dim=self.split)
        if to_lower_rank:
            # data coming from a higher rank is placed behind the local data
            self.__array = torch.cat((self.__array, incoming), dim=self.split)
def resplit_(self, axis: Optional[int] = None) -> DNDarray:
    """
    In-place option for resplitting a :class:`DNDarray`.

    The local tensor, the split axis and the cached lshape map / partitions dictionary of
    ``self`` are updated; ``self`` is returned in every case.

    Parameters
    ----------
    axis : int or None
        The new split axis, ``None`` denotes gathering, an int will set the new split axis

    Examples
    --------
    >>> a = ht.zeros(
    ...     (
    ...         4,
    ...         5,
    ...     ),
    ...     split=0,
    ... )
    >>> a.lshape
    (0/2) (2, 5)
    (1/2) (2, 5)
    >>> a.resplit_(None)
    >>> a.split
    None
    >>> a.lshape
    (0/2) (4, 5)
    (1/2) (4, 5)
    >>> a = ht.zeros(
    ...     (
    ...         4,
    ...         5,
    ...     ),
    ...     split=0,
    ... )
    >>> a.lshape
    (0/2) (2, 5)
    (1/2) (2, 5)
    >>> a.resplit_(1)
    >>> a.split
    1
    >>> a.lshape
    (0/2) (4, 3)
    (1/2) (4, 2)
    """
    # sanitize the axis to check whether it is in range
    axis = sanitize_axis(self.shape, axis)

    # early out for unchanged content
    if self.comm.size == 1:
        # single process: no data has to move, only the split attribute is updated
        # (presumably ``self.split`` then equals ``axis``, so the check below returns)
        self.__split = axis
        if axis is None:
            self.__partitions_dict__ = None
    if axis == self.split:
        return self

    # the split changes from here on, so the cached partitions dictionary is discarded
    self.__partitions_dict__ = None
    if axis is None:
        # gathering: every process receives the chunks of all processes into a buffer of
        # the global shape
        gathered = torch.empty(
            self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
        )
        counts, displs = self.counts_displs()
        self.comm.Allgatherv(self.__array, (gathered, counts, displs), recv_axis=self.split)
        self.__array = gathered
        self.__split = axis
        self.__lshape_map = None
        return self
    # tensor needs be split/sliced locally
    if self.split is None:
        # new_arr = self
        # the local tensor is indexed with the chunk slices for the new axis; no
        # communication call is made in this branch
        _, _, slices = self.comm.chunk(self.shape, axis)
        temp = self.__array[slices]
        self.__array = torch.empty((1,), device=self.device.torch_device)
        # necessary to clear storage of local __array
        self.__array = temp.clone().detach()
        self.__split = axis
        self.__lshape_map = None
        return self

    # general case: resplit from one axis to another via the tiling helpers
    arr_tiles = tiling.SplitTiles(self)
    new_tiles = tiling.SplitTiles(self)
    gshape = self.shape
    new_lshape = list(gshape)
    # local extent along the new split axis as given by the tile dimensions for this rank
    new_lshape[axis] = int(arr_tiles.tile_dimensions[axis][self.comm.rank].item())
    recv_buffer = torch.empty(
        tuple(new_lshape), dtype=self.dtype.torch_type(), device=self.device.torch_device
    )

    self._axis2axisResplit(
        self.larray, self.split, arr_tiles, recv_buffer, axis, new_tiles, self.comm
    )

    self.__array = recv_buffer
    self.__split = axis
    self.__lshape_map = None
    return self
def __setitem__(
    self,
    key: Union[int, Tuple[int, ...], List[int, ...]],
    value: Union[float, DNDarray, torch.Tensor],
):
    """
    Global item setter

    The key is given in global coordinates. It is first normalized into a per-dimension list
    (DNDarray/torch keys are converted, boolean masks become index tensors, an ellipsis is
    expanded, one-element tensors become Python scalars). If ``self`` is not distributed the
    value is written directly via the private ``__setter``; otherwise the entry of the key
    along the split axis is translated into local coordinates and only the processes holding
    affected data write to their local tensor.

    Parameters
    ----------
    key : Union[int, Tuple[int,...], List[int,...]]
        Index/indices to be set
    value: Union[float, DNDarray,torch.Tensor]
        Value to be set to the specified positions in the DNDarray (self)

    Raises
    ------
    ValueError
        If the key contains more than one ellipsis.

    Notes
    -----
    If a ``DNDarray`` is given as the value to be set then the split axes are assumed to be equal.
    If they are not, PyTorch will raise an error when the values are attempted to be set on the local array

    Examples
    --------
    >>> a = ht.zeros((4, 5), split=0)
    (1/2) >>> tensor([[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]])
    (2/2) >>> tensor([[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]])
    >>> a[1:4, 1] = 1
    >>> a
    (1/2) >>> tensor([[0., 0., 0., 0., 0.],
                      [0., 1., 0., 0., 0.]])
    (2/2) >>> tensor([[0., 1., 0., 0., 0.],
                      [0., 1., 0., 0., 0.]])
    """
    # NOTE(review): this looks up an attribute literally named "copy()", which a normal
    # object does not have, so the statement leaves ``key`` unchanged. Presumably a copy of
    # the key (``key.copy()``) was intended -- confirm before changing.
    key = getattr(key, "copy()", key)
    try:
        # warn if value carries an integer split that differs from the split of self;
        # values without a ``split`` attribute (AttributeError) and values whose split is
        # None (``int(None)`` raises TypeError) are skipped silently
        if value.split != self.split:
            val_split = int(value.split)
            sp = self.split
            warnings.warn(
                f"\nvalue.split {val_split} not equal to this DNDarray's split:"
                f" {sp}. this may cause errors or unwanted behavior",
                category=RuntimeWarning,
            )
    except (AttributeError, TypeError):
        pass

    # NOTE: for whatever reason, there is an inplace op which interferes with the abstraction
    # of this next block of code. this is shared with __getitem__. I attempted to abstract it
    # in a standard way, but it was causing errors in the test suite. If someone else is
    # motived to do this they are welcome to, but i have no time right now
    # print(key)
    if isinstance(key, DNDarray) and key.ndim == self.ndim:
        """if the key is a DNDarray and it has as many dimensions as self, then each of the
        entries in the 0th dim refer to a single element. To handle this, the key is split
        into the torch tensors for each dimension. This signals that advanced indexing is
        to be used."""
        key = manipulations.resplit(key)
        if key.larray.dtype in [torch.bool, torch.uint8]:
            key = indexing.nonzero(key)

        if key.ndim > 1:
            key = list(key.larray.split(1, dim=1))
            # key is now a list of tensors with dimensions (key.ndim, 1)
            # squeeze singleton dimension:
            key = [key[i].squeeze_(1) for i in range(len(key))]
        else:
            key = [key]
    elif not isinstance(key, tuple):
        """this loop handles all other cases. DNDarrays which make it to here refer to
        advanced indexing slices, as do the torch tensors. Both DNDaarrys and torch.Tensors
        are cast into lists here by PyTorch. lists mean advanced indexing will be used"""
        # a non-tuple key only addresses dimension 0; all other dimensions are taken fully
        h = [slice(None, None, None)] * self.ndim
        if isinstance(key, DNDarray):
            key = manipulations.resplit(key)
            if key.larray.dtype in [torch.bool, torch.uint8]:
                h[0] = torch.nonzero(key.larray).flatten()  # .tolist()
            else:
                h[0] = key.larray.tolist()
        elif isinstance(key, torch.Tensor):
            if key.dtype in [torch.bool, torch.uint8]:
                # (coquelin77) im not sure why this works without being a list...but it does...for now
                h[0] = torch.nonzero(key).flatten()  # .tolist()
            else:
                h[0] = key.tolist()
        else:
            h[0] = key
        key = list(h)

    # key must be torch-proof
    if isinstance(key, (list, tuple)):
        key = list(key)
        for i, k in enumerate(key):
            try:  # extract torch tensor
                k = manipulations.resplit(k)
                key[i] = k.larray
            except AttributeError:
                pass
            # remove bools from a torch tensor in favor of indexes
            try:
                if key[i].dtype in [torch.bool, torch.uint8]:
                    key[i] = torch.nonzero(key[i]).flatten()
            except (AttributeError, TypeError):
                pass

    key = list(key)

    # ellipsis stuff
    key_classes = [type(n) for n in key]
    # if any(isinstance(n, ellipsis) for n in key):
    n_elips = key_classes.count(type(...))
    if n_elips > 1:
        raise ValueError("key can only contain 1 ellipsis")
    elif n_elips == 1:
        # get which item is the ellipsis
        ell_ind = key_classes.index(type(...))
        kst = key[:ell_ind]
        kend = key[ell_ind + 1 :]
        # the ellipsis stands for as many full slices as dimensions are left unaddressed
        slices = [slice(None)] * (self.ndim - (len(kst) + len(kend)))
        key = kst + slices + kend
    # ---------- end ellipsis stuff -------------

    # entries offering ``item()`` (e.g. one-element tensors) are replaced by its result;
    # entries without ``item`` or for which it fails are left untouched
    for c, k in enumerate(key):
        try:
            key[c] = k.item()
        except (AttributeError, ValueError, RuntimeError):
            pass

    rank = self.comm.rank
    if self.split is not None:
        counts, chunk_starts = self.counts_displs()
    else:
        counts, chunk_starts = 0, [0] * self.comm.size
    counts = torch.tensor(counts, device=self.device.torch_device)
    chunk_starts = torch.tensor(chunk_starts, device=self.device.torch_device)
    chunk_ends = chunk_starts + counts
    chunk_start = chunk_starts[rank]
    chunk_end = chunk_ends[rank]
    # determine which elements are on the local process (if the key is a torch tensor)
    try:
        # if isinstance(key[self.split], torch.Tensor):
        filter_key = torch.nonzero(
            (chunk_start <= key[self.split]) & (key[self.split] < chunk_end)
        )
        for k in range(len(key)):
            try:
                key[k] = key[k][filter_key].flatten()
            except TypeError:
                pass
    except TypeError:  # this will happen if the key doesnt have that many
        pass

    key = tuple(key)

    if not self.is_distributed():
        return self.__setter(key, value)  # returns None

    # raise RuntimeError("split axis of array and the target value are not equal") removed
    # this will occur if the local shapes do not match
    rank = self.comm.rank
    # chunk boundaries along the split axis are recomputed from ``comm.chunk`` here:
    # ``ends`` holds the chunk length per process, ``chunk_ends`` their cumulative sum and
    # ``chunk_starts`` the same values with a leading 0 (one entry more than processes)
    ends = []
    for pr in range(self.comm.size):
        _, _, e = self.comm.chunk(self.shape, self.split, rank=pr)
        ends.append(e[self.split].stop - e[self.split].start)
    ends = torch.tensor(ends, device=self.device.torch_device)
    chunk_ends = ends.cumsum(dim=0)
    chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
    _, _, chunk_slice = self.comm.chunk(self.shape, self.split)
    chunk_start = chunk_slice[self.split].start
    chunk_end = chunk_slice[self.split].stop

    self_proxy = self.__torch_proxy__()

    # if the value is a DNDarray, the divisions need to be balanced:
    #   this means that we need to know how much data is where for both DNDarrays
    #   if the value data is not in the right place, then it will need to be moved

    if isinstance(key[self.split], slice):
        key = list(key)
        key_start = key[self.split].start if key[self.split].start is not None else 0
        key_stop = (
            key[self.split].stop
            if key[self.split].stop is not None
            else self.gshape[self.split]
        )

        if key_stop < 0:
            key_stop = self.gshape[self.split] + key[self.split].stop
        key_step = key[self.split].step
        og_key_start = key_start
        # first and last process whose chunk overlaps the requested slice
        st_pr = torch.where(key_start < chunk_ends)[0]
        st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
        sp_pr = torch.where(key_stop >= chunk_starts)[0]
        sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
        actives = list(range(st_pr, sp_pr + 1))

        if (
            isinstance(value, type(self))
            and value.split is not None
            and value.shape[self.split] != self.shape[self.split]
        ):
            # setting elements in self with a DNDarray which is not the same size in the
            # split dimension
            local_keys = []
            # below is used if the target needs to be reshaped
            target_reshape_map = torch.zeros(
                (self.comm.size, self.ndim), dtype=torch.int64, device=self.device.torch_device
            )
            for r in range(self.comm.size):
                if r not in actives:
                    # placeholder key for inactive ranks; it is never used for indexing
                    # because inactive ranks return before the setter is called
                    loc_key = key.copy()
                    loc_key[self.split] = slice(0, 0, 0)
                else:
                    key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r]
                    key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r]
                    key_start_l, key_stop_l = self.__xitem_get_key_start_stop(
                        r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start
                    )
                    loc_key = key.copy()
                    loc_key[self.split] = slice(key_start_l, key_stop_l, key_step)

                    # shape that the local key selects on rank r, used as the target
                    # distribution of value
                    gout_full = torch.tensor(
                        self_proxy[tuple(loc_key)].shape, device=self.device.torch_device
                    )
                    target_reshape_map[r] = gout_full
                local_keys.append(loc_key)

            key = local_keys[rank]
            value = value.redistribute(target_map=target_reshape_map)

            if rank not in actives:
                return  # non-active ranks can exit here

            chunk_starts_v = target_reshape_map[:, self.split]
            value_slice = [slice(None, None, None)] * value.ndim
            step2 = key_step if key_step is not None else 1
            key_start = (chunk_starts_v[rank] - og_key_start).item()

            key_start = max(key_start, 0)
            key_stop = key_start + key_stop
            slice_loc = min(self.split, value.ndim - 1)
            value_slice[slice_loc] = slice(
                key_start, math.ceil(torch.true_divide(key_stop, step2)), 1
            )

            # NOTE(review): ``value_slice`` is computed above but not applied here; the
            # complete local tensor of the redistributed value is written
            self.__setter(tuple(key), value.larray)
            return

        # if rank in actives:
        if rank not in actives:
            return  # non-active ranks can exit here
        # translate the global slice bounds into coordinates of the local chunk
        key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
        key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
        key_start, key_stop = self.__xitem_get_key_start_stop(
            rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start
        )
        key[self.split] = slice(key_start, key_stop, key_step)

        # todo: need to slice the values to be the right size...
        if isinstance(value, (torch.Tensor, type(self))):
            # if its a torch tensor, it is assumed to exist on all processes
            value_slice = [slice(None, None, None)] * value.ndim
            step2 = key_step if key_step is not None else 1
            key_start = (chunk_starts[rank] - og_key_start).item()
            key_start = max(key_start, 0)
            key_stop = key_start + key_stop
            slice_loc = min(self.split, value.ndim - 1)
            value_slice[slice_loc] = slice(
                key_start, math.ceil(torch.true_divide(key_stop, step2)), 1
            )
            self.__setter(tuple(key), value[tuple(value_slice)])
        else:
            self.__setter(tuple(key), value)

    elif isinstance(key[self.split], (torch.Tensor, list)):
        # index tensor along the split axis: shift to local coordinates and write only if
        # any index is left on this process
        key = list(key)
        key[self.split] -= chunk_start
        if len(key[self.split]) != 0:
            self.__setter(tuple(key), value)

    elif key[self.split] in range(chunk_start, chunk_end):
        # single non-negative index that lies in the local chunk
        key = list(key)
        key[self.split] = key[self.split] - chunk_start
        self.__setter(tuple(key), value)

    elif key[self.split] < 0:
        # single negative index: counted from the end of the global split dimension
        key = list(key)
        if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end):
            key[self.split] = key[self.split] + self.shape[self.split] - chunk_start
            self.__setter(tuple(key), value)
def __setter(
    self,
    key: Union[int, Tuple[int, ...], List[int, ...]],
    value: Union[float, DNDarray, torch.Tensor],
):
    """
    Utility function for checking ``value`` and writing it into the local torch tensor at
    the (local) position ``key``.

    Scalars are written as they are, torch tensors via their ``data``, a ``DNDarray`` via its
    local tensor. Lists, tuples and numpy arrays are converted to a torch tensor on the device
    of this array first.

    Parameters
    ----------
    key : Union[int, Tuple[int,...], List[int,...]]
        Local index/indices to be set
    value : Union[float, DNDarray, torch.Tensor]
        Value to be set; lists, tuples and numpy arrays are accepted as well

    Raises
    ------
    NotImplementedError
        If the type of ``value`` is not supported
    """
    local_tensor = self.__array
    if np.isscalar(value):
        local_tensor.__setitem__(key, value)
    elif isinstance(value, torch.Tensor):
        local_tensor.__setitem__(key, value.data)
    elif isinstance(value, (list, tuple)):
        converted = torch.tensor(value, device=self.device.torch_device)
        local_tensor.__setitem__(key, converted.data)
    elif isinstance(value, np.ndarray):
        # move the converted array to the device of this array, as is done for lists and tuples
        converted = torch.from_numpy(value).to(self.device.torch_device)
        local_tensor.__setitem__(key, converted.data)
    elif isinstance(value, DNDarray):
        local_tensor.__setitem__(key, value.__array)
    else:
        raise NotImplementedError(f"Not implemented for {value.__class__.__name__}")
def __str__(self) -> str:
    """
    Computes a string representation of the ``DNDarray``.

    The text is produced by ``printing.__str__``.
    """
    text = printing.__str__(self)
    return text
def tolist(self, keepsplit: bool = False) -> List:
    """
    Return a copy of the array data as a (nested) Python list. For scalars, a standard Python number is returned.

    By default the data of a gathered version of the array (``resplit(axis=None)``) is
    returned; with ``keepsplit=True`` only the local data of the calling process is converted.

    Parameters
    ----------
    keepsplit: bool
        Whether the list should be returned locally or globally.

    Examples
    --------
    >>> a = ht.array([[0, 1], [2, 3]])
    >>> a.tolist()
    [[0, 1], [2, 3]]

    >>> a = ht.array([[0, 1], [2, 3]], split=0)
    >>> a.tolist()
    [[0, 1], [2, 3]]

    >>> a = ht.array([[0, 1], [2, 3]], split=1)
    >>> a.tolist(keepsplit=True)
    (1/2) [[0], [2]]
    (2/2) [[1], [3]]
    """
    source = self if keepsplit else self.resplit(axis=None)
    return source.__array.tolist()
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """
    Supports PyTorch's dispatch mechanism.

    The heat function carrying the same name as ``func`` is looked up in the ``heat``
    package and called with the given arguments. If no attribute of that name is found,
    ``NotImplemented`` is returned.
    """
    import heat

    call_kwargs = {} if kwargs is None else kwargs
    try:
        heat_equivalent = getattr(heat, func.__name__)
    except AttributeError:
        return NotImplemented
    else:
        return heat_equivalent(*args, **call_kwargs)
def __torch_proxy__(self) -> torch.Tensor:
    """
    Return a 1-element `torch.Tensor` strided as the global `self` shape.
    Used internally for sanitation purposes.

    The proxy is an ``int8`` tensor holding a single ``1`` on the device of the local tensor;
    all strides are 0, so it has the global shape without allocating the global data.
    """
    single_element = torch.ones((1,), dtype=torch.int8, device=self.larray.device)
    global_shape = self.gshape
    zero_strides = [0] * self.ndim
    return single_element.as_strided(global_shape, zero_strides)
@staticmethod
def __xitem_get_key_start_stop(
    rank: int,
    actives: list,
    key_st: int,
    key_sp: int,
    step: int,
    ends: torch.Tensor,
    og_key_st: int,
) -> Tuple[int, int]:
    """
    Adjust the local start of a stepped slice for setitem and getitem and return start and
    stop as plain Python numbers.

    For every process after the first active one, the start is shifted according to the
    remainder of ``(ends[rank - 1] - og_key_st)`` modulo ``step``; steps below 2 and a
    remainder of 0 leave the start untouched. ``torch.Tensor`` start/stop values are
    converted via ``item()``.
    """
    beyond_first_active = step is not None and rank > actives[0]
    if beyond_first_active:
        offset = (ends[rank - 1] - og_key_st) % step
        if offset > 0:
            if step > 2:
                key_st += step - offset
            elif step == 2:
                key_st += offset

    start = key_st.item() if isinstance(key_st, torch.Tensor) else key_st
    stop = key_sp.item() if isinstance(key_sp, torch.Tensor) else key_sp
    return start, stop
# HeAT imports at the end to break cyclic dependencies from . import complex_math from . import devices from . import factories from . import indexing from . import linalg from . import manipulations from . import printing from . import rounding from . import sanitation from . import statistics from . import stride_tricks from . import tiling from . import types from .devices import Device from .stride_tricks import sanitize_axis from .types import datatype, canonical_heat_type