"""
Functions relating to indices of items within DNDarrays, i.e. `where()`
"""
import torch
from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence
from .communication import MPI
from .dndarray import DNDarray
from . import sanitation
from . import types
__all__ = ["nonzero", "where"]
[docs]
def nonzero(x: DNDarray) -> DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``).
If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
can be UNBALANCED as it contains the indices of the non-zero elements on each node.
Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension.
The values in ``x`` are always tested and returned in row-major, C-style order.
The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.
Parameters
----------
x: DNDarray
Input array
Examples
--------
>>> import heat as ht
>>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0)
>>> ht.nonzero(x)
DNDarray([[0, 0],
[1, 1],
[1, 2],
[2, 1]], dtype=ht.int64, device=cpu:0, split=0)
>>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0)
>>> y > 3
DNDarray([[False, False, False],
[ True, True, True],
[ True, True, True]], dtype=ht.bool, device=cpu:0, split=0)
>>> ht.nonzero(y > 3)
DNDarray([[1, 0],
[1, 1],
[1, 2],
[2, 0],
[2, 1],
[2, 2]], dtype=ht.int64, device=cpu:0, split=0)
>>> y[ht.nonzero(y > 3)]
DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0)
"""
sanitation.sanitize_in(x)
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)
# add offsets mapping from local indices to global indices if x is split
if x.split is not None:
_, _, slices = x.comm.chunk(x.shape, x.split)
lcl_nonzero[..., x.split] += slices[x.split].start
if x.ndim == 1:
lcl_nonzero = lcl_nonzero.squeeze(dim=1)
# compute global shape of the index array
gout = list(lcl_nonzero.shape)
if x.split is None:
is_split = None
else:
gout[0] = x.comm.allreduce(gout[0], MPI.SUM)
is_split = 0
return DNDarray(
lcl_nonzero,
gshape=tuple(gout),
dtype=types.canonical_heat_type(lcl_nonzero.dtype),
split=is_split,
device=x.device,
comm=x.comm,
balanced=False,
)
DNDarray.nonzero = lambda self: nonzero(self)
DNDarray.nonzero.__doc__ = nonzero.__doc__
[docs]
def where(
cond: DNDarray,
x: Union[None, int, float, DNDarray] = None,
y: Union[None, int, float, DNDarray] = None,
) -> DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition.
Result is a :class:`~heat.core.dndarray.DNDarray` with elements from ``x`` where cond is ``True``,
and elements from ``y`` elsewhere (``False``).
Parameters
----------
cond : DNDarray
Condition of interest, where true yield ``x`` otherwise yield ``y``
x : DNDarray or int or float, optional
Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
y : DNDarray or int or float, optional
Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
Raises
------
NotImplementedError
if splits of the two input :class:`~heat.core.dndarray.DNDarray` differ
TypeError
if only x or y is given or both are not DNDarrays or numerical scalars
Notes
-----
When only condition is provided, this function is a shorthand for :func:`nonzero`.
Examples
--------
>>> import heat as ht
>>> x = ht.arange(10, split=0)
>>> ht.where(x < 5, x, 10 * x)
DNDarray([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90], dtype=ht.int64, device=cpu:0, split=0)
>>> y = ht.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
>>> ht.where(y < 4, y, -1)
DNDarray([[ 0, 1, 2],
[ 0, 2, -1],
[ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None)
"""
if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)):
if (isinstance(x, DNDarray) and cond.split != x.split) or (
isinstance(y, DNDarray) and cond.split != y.split
):
if len(y.shape) >= 1 and y.shape[0] > 1:
raise NotImplementedError("binary op not implemented for different split axes")
if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)):
for var in [x, y]:
if isinstance(var, int):
var = float(var)
return cond.dtype(cond == 0) * y + cond * x
elif x is None and y is None:
return nonzero(cond)
else:
raise TypeError(
f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})"
)