# Source code for heat.core.relational

"""
Functions for relational operations, i.e. equal/not equal, greater/less than (or equal).
"""

from __future__ import annotations

import torch
import numpy as np

from typing import Union

from .communication import MPI
from .dndarray import DNDarray
from . import _operations
from . import dndarray
from . import types
from . import sanitation
from . import factories
from . import devices

# Public names exported by this module: the element-wise comparison functions
# (eq, ge, gt, le, lt, ne), their NumPy-style aliases, and the global check ``equal``.
__all__ = [
    "eq",
    "equal",
    "ge",
    "greater",
    "greater_equal",
    "gt",
    "le",
    "less",
    "less_equal",
    "lt",
    "ne",
    "not_equal",
]


def eq(x, y) -> DNDarray:
    """
    Element-wise equality test ``x == y``, returned as a boolean
    :class:`~heat.core.dndarray.DNDarray`. Either operand may be a scalar or a
    :class:`~heat.core.dndarray.DNDarray`. If combining the operands element-wise raises a
    ``TypeError`` or ``ValueError`` (e.g. one operand is neither a scalar nor a
    :class:`~heat.core.dndarray.DNDarray`), ``False`` is returned instead.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand involved in the comparison
    y: DNDarray or scalar
        The second operand involved in the comparison

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.eq(x, 3.0)
    DNDarray([[False, False],
              [ True, False]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.eq(x, y)
    DNDarray([[False,  True],
              [False, False]], dtype=ht.bool, device=cpu:0, split=None)
    >>> ht.eq(x, slice(None))
    False
    """
    try:
        outcome = _operations.__binary_op(torch.eq, x, y)
        if outcome.dtype != types.bool:
            # Normalise to a boolean array while keeping the distribution metadata.
            bool_tensor = outcome.larray.type(torch.bool)
            outcome = dndarray.DNDarray(
                bool_tensor,
                outcome.gshape,
                types.bool,
                outcome.split,
                outcome.device,
                outcome.comm,
                outcome.balanced,
            )
    except (TypeError, ValueError):
        # Operands that cannot be combined element-wise are reported as "not equal".
        return False
    return outcome
# ``x == y`` on a DNDarray dispatches to the element-wise ``eq`` and exposes the same documentation.
DNDarray.__eq__ = lambda self, other: eq(self, other)
DNDarray.__eq__.__doc__ = eq.__doc__
def _distribute_like(replicated: DNDarray, reference: DNDarray) -> DNDarray:
    """
    Split the non-distributed ``replicated`` along ``reference.split`` so that every process holds
    the slice matching ``reference``'s local chunk. Helper of :func:`equal`.
    """
    if reference.is_balanced(force_check=False):
        return factories.array(
            replicated,
            split=reference.split,
            copy=False,
            comm=replicated.comm,
            device=replicated.device,
        )
    # Unbalanced reference: cut this process's slice according to the reference's lshape map.
    target_map = reference.lshape_map
    rank = replicated.comm.rank
    idx = [slice(None)] * replicated.ndim
    idx[reference.split] = slice(
        target_map[:rank, reference.split].sum(),
        target_map[: rank + 1, reference.split].sum(),
    )
    return factories.array(
        replicated.larray[tuple(idx)],
        is_split=reference.split,
        copy=False,
        comm=replicated.comm,
        device=replicated.device,
    )


def equal(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> bool:
    """
    Overall comparison of equality between two operands. Returns ``True`` if ``x`` and ``y``
    have the same shape and elements, and ``False`` otherwise.

    A DNDarray holding exactly one element is compared through its scalar value (``item()``).
    A DNDarray with more than one element never equals a scalar. The local results are combined
    with a logical-AND ``allreduce``, so every process receives the same answer.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand involved in the comparison
    y: DNDarray or scalar
        The second operand involved in the comparison

    Raises
    ------
    NotImplementedError
        If ``x`` and ``y`` are DNDarrays with different communicators.
    ValueError
        If ``x`` and ``y`` are both distributed, but along different split axes.

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.equal(x, ht.float32([[1, 2], [3, 4]]))
    True
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.equal(x, y)
    False
    >>> ht.equal(x, 3.0)
    False
    """
    if np.isscalar(x) and np.isscalar(y):
        x = factories.array(x)
        y = factories.array(y)
    elif isinstance(x, DNDarray) and np.isscalar(y):
        # An array can only equal a scalar if it holds exactly one element.
        if x.gnumel == 1:
            return equal(x.item(), y)
        return False
    elif np.isscalar(x) and isinstance(y, DNDarray):
        if y.gnumel == 1:
            return equal(x, y.item())
        return False
    else:
        # Both operands are treated as DNDarrays from here on.
        if x.gnumel == 1:
            return equal(x.item(), y)
        if y.gnumel == 1:
            return equal(x, y.item())
        if x.comm != y.comm:
            raise NotImplementedError("Not implemented for other comms")
        if x.gshape != y.gshape:
            return False

        # Bring both operands into the same distribution so that local tensors line up.
        if x.split is None and y.split is None:
            pass  # both fully replicated: nothing to align
        elif x.split is None:
            x = _distribute_like(x, y)
        elif y.split is None:
            y = _distribute_like(y, x)
        elif x.split != y.split:
            raise ValueError(
                f"DNDarrays must have the same split axes, found {x.split} and {y.split}"
            )
        elif not (x.is_balanced(force_check=False) and y.is_balanced(force_check=False)):
            # Same split axis, but the chunks may differ: rebalance when the lshape maps disagree.
            x_lmap = x.lshape_map
            y_lmap = y.lshape_map
            if not torch.equal(x_lmap, y_lmap):
                x = x.balance()
                y = y.balance()

    result_type = types.result_type(x, y)
    # On MPS tensors float64 is downgraded to float32 (presumably no float64 support there).
    is_mps = x.larray.is_mps or y.larray.is_mps
    if is_mps and result_type is types.float64:
        result_type = types.float32

    x = x.astype(result_type)
    y = y.astype(result_type)

    # Processes without local elements contribute True so they do not veto the logical AND.
    if x.larray.numel() > 0:
        result_value = torch.equal(x.larray, y.larray)
    else:
        result_value = True

    return x.comm.allreduce(result_value, MPI.LAND)
def ge(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
    """
    Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich
    greater than or equal comparison between values from operand ``x`` with respect to values of
    operand ``y`` (i.e. ``x>=y``), not commutative. Takes the first and second operand (scalar or
    :class:`~heat.core.dndarray.DNDarray`) whose elements are to be compared as argument.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand to be compared greater than or equal to second operand
    y: DNDarray or scalar
        The second operand to be compared less than or equal to first operand

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.ge(x, 3.0)
    DNDarray([[False, False],
              [ True,  True]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.ge(x, y)
    DNDarray([[False,  True],
              [ True,  True]], dtype=ht.bool, device=cpu:0, split=None)
    """
    res = _operations.__binary_op(torch.ge, x, y)

    # Guarantee a boolean result while keeping the distribution metadata.
    if res.dtype != types.bool:
        res = dndarray.DNDarray(
            res.larray.type(torch.bool),
            res.gshape,
            types.bool,
            res.split,
            res.device,
            res.comm,
            res.balanced,
        )

    return res
# ``x >= y`` on a DNDarray dispatches to the element-wise ``ge`` and exposes the same documentation.
DNDarray.__ge__ = lambda self, other: ge(self, other)
DNDarray.__ge__.__doc__ = ge.__doc__

# NumPy-style alias. It is the very same function object as ``ge``, so it already carries
# ``ge``'s docstring; a separate ``__doc__`` assignment would be a self-assignment.
greater_equal = ge
def gt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
    """
    Element-wise strict ``x > y`` comparison, returned as a boolean
    :class:`~heat.core.dndarray.DNDarray`. The operation is not commutative. Each operand may be a
    scalar or a :class:`~heat.core.dndarray.DNDarray`.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand to be compared greater than second operand
    y: DNDarray or scalar
        The second operand to be compared less than first operand

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.gt(x, 3.0)
    DNDarray([[False, False],
              [False,  True]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.gt(x, y)
    DNDarray([[False, False],
              [ True,  True]], dtype=ht.bool, device=cpu:0, split=None)
    """
    comparison = _operations.__binary_op(torch.gt, x, y)
    if comparison.dtype != types.bool:
        # Re-wrap the local data as bool, preserving shape, split, device, comm and balance state.
        bool_tensor = comparison.larray.type(torch.bool)
        comparison = dndarray.DNDarray(
            bool_tensor,
            comparison.gshape,
            types.bool,
            comparison.split,
            comparison.device,
            comparison.comm,
            comparison.balanced,
        )
    return comparison
# ``x > y`` on a DNDarray dispatches to the element-wise ``gt`` and exposes the same documentation.
DNDarray.__gt__ = lambda self, other: gt(self, other)
DNDarray.__gt__.__doc__ = gt.__doc__

# NumPy-style alias. It is the very same function object as ``gt``, so it already carries
# ``gt``'s docstring; a separate ``__doc__`` assignment would be a self-assignment.
greater = gt
def le(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
    """
    Element-wise ``x <= y`` comparison, returned as a boolean
    :class:`~heat.core.dndarray.DNDarray`. The operation is not commutative. Each operand may be a
    scalar or a :class:`~heat.core.dndarray.DNDarray`.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand to be compared less than or equal to second operand
    y: DNDarray or scalar
        The second operand to be compared greater than or equal to first operand

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.le(x, 3.0)
    DNDarray([[ True,  True],
              [ True, False]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.le(x, y)
    DNDarray([[ True,  True],
              [False, False]], dtype=ht.bool, device=cpu:0, split=None)
    """
    comparison = _operations.__binary_op(torch.le, x, y)
    if comparison.dtype != types.bool:
        # Re-wrap the local data as bool, preserving shape, split, device, comm and balance state.
        bool_tensor = comparison.larray.type(torch.bool)
        comparison = dndarray.DNDarray(
            bool_tensor,
            comparison.gshape,
            types.bool,
            comparison.split,
            comparison.device,
            comparison.comm,
            comparison.balanced,
        )
    return comparison
# ``x <= y`` on a DNDarray dispatches to the element-wise ``le`` and exposes the same documentation.
DNDarray.__le__ = lambda self, other: le(self, other)
DNDarray.__le__.__doc__ = le.__doc__

# NumPy-style alias. It is the very same function object as ``le``, so it already carries
# ``le``'s docstring; a separate ``__doc__`` assignment would be a self-assignment.
less_equal = le
def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
    """
    Element-wise strict ``x < y`` comparison, returned as a boolean
    :class:`~heat.core.dndarray.DNDarray`. The operation is not commutative. Each operand may be a
    scalar or a :class:`~heat.core.dndarray.DNDarray`.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand to be compared less than second operand
    y: DNDarray or scalar
        The second operand to be compared greater than first operand

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.lt(x, 3.0)
    DNDarray([[ True,  True],
              [False, False]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.lt(x, y)
    DNDarray([[ True, False],
              [False, False]], dtype=ht.bool, device=cpu:0, split=None)
    """
    comparison = _operations.__binary_op(torch.lt, x, y)
    if comparison.dtype != types.bool:
        # Re-wrap the local data as bool, preserving shape, split, device, comm and balance state.
        bool_tensor = comparison.larray.type(torch.bool)
        comparison = dndarray.DNDarray(
            bool_tensor,
            comparison.gshape,
            types.bool,
            comparison.split,
            comparison.device,
            comparison.comm,
            comparison.balanced,
        )
    return comparison
# ``x < y`` on a DNDarray dispatches to the element-wise ``lt`` and exposes the same documentation.
DNDarray.__lt__ = lambda self, other: lt(self, other)
DNDarray.__lt__.__doc__ = lt.__doc__

# NumPy-style alias. It is the very same function object as ``lt``, so it already carries
# ``lt``'s docstring; a separate ``__doc__`` assignment would be a self-assignment.
less = lt
def ne(x, y) -> DNDarray:
    """
    Element-wise inequality test ``x != y``, returned as a boolean
    :class:`~heat.core.dndarray.DNDarray`. The operation is commutative. Either operand may be a
    scalar or a :class:`~heat.core.dndarray.DNDarray`. If combining the operands element-wise
    raises a ``TypeError`` or ``ValueError`` (e.g. one operand is neither a scalar nor a
    :class:`~heat.core.dndarray.DNDarray`), ``True`` is returned instead.

    Parameters
    ----------
    x: DNDarray or scalar
        The first operand involved in the comparison
    y: DNDarray or scalar
        The second operand involved in the comparison

    Examples
    --------
    >>> import heat as ht
    >>> x = ht.float32([[1, 2], [3, 4]])
    >>> ht.ne(x, 3.0)
    DNDarray([[ True,  True],
              [False,  True]], dtype=ht.bool, device=cpu:0, split=None)
    >>> y = ht.float32([[2, 2], [2, 2]])
    >>> ht.ne(x, y)
    DNDarray([[ True, False],
              [ True,  True]], dtype=ht.bool, device=cpu:0, split=None)
    >>> ht.ne(x, slice(None))
    True
    """
    try:
        outcome = _operations.__binary_op(torch.ne, x, y)
        if outcome.dtype != types.bool:
            # Normalise to a boolean array while keeping the distribution metadata.
            bool_tensor = outcome.larray.type(torch.bool)
            outcome = dndarray.DNDarray(
                bool_tensor,
                outcome.gshape,
                types.bool,
                outcome.split,
                outcome.device,
                outcome.comm,
                outcome.balanced,
            )
    except (TypeError, ValueError):
        # Operands that cannot be combined element-wise are reported as "not equal".
        return True
    return outcome
# ``x != y`` on a DNDarray dispatches to the element-wise ``ne`` and exposes the same documentation.
DNDarray.__ne__ = lambda self, other: ne(self, other)
DNDarray.__ne__.__doc__ = ne.__doc__

# NumPy-style alias. It is the very same function object as ``ne``, so it already carries
# ``ne``'s docstring; a separate ``__doc__`` assignment would be a self-assignment.
not_equal = ne