from __future__ import annotations
from ._array_object import Array
from ._dtypes import _result_type, _numeric_dtypes
from typing import Optional, Tuple
import heat as ht
[docs]
def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
"""
Returns the indices of the maximum values along a specified axis. When the
maximum value occurs multiple times, only the indices corresponding to the
first occurrence are returned.
Parameters
----------
x : Array
Input array. Must have a numeric data type.
axis : Optional[int]
Axis along which to search. If ``None``, the function returns the index of
the maximum value of the flattened array. Default: ``None``.
keepdims : bool
If ``True``, the reduced axes (dimensions) are included in the result as
singleton dimensions. Otherwise, if ``False``, the reduced axes (dimensions)
are not be included in the result. Default: ``False``.
"""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in argmax")
res = ht.argmax(x._array, axis=axis, keepdim=keepdims)
return Array._new(res)
[docs]
def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
"""
Returns the indices of the minimum values along a specified axis. When the
minimum value occurs multiple times, only the indices corresponding to the
first occurrence are returned.
Parameters
----------
x : Array
Input array. Must have a numeric data type.
axis : Optional[int]
Axis along which to search. If ``None``, the function returns the index of
the minimum value of the flattened array. Default: ``None``.
keepdims : bool
If ``True``, the reduced axes (dimensions) are included in the result as
singleton dimensions. Otherwise, if ``False``, the reduced axes (dimensions)
are not be included in the result. Default: ``False``.
"""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in argmin")
res = ht.argmin(x._array, axis=axis, keepdim=keepdims)
return Array._new(res)
[docs]
def nonzero(x: Array, /) -> Tuple[Array, ...]:
"""
Returns the indices of the array elements which are non-zero.
Parameters
----------
x : Array
Input array. Must have a positive rank.
"""
# See PR #914 for overhaul
return tuple(Array._new(i) for i in ht.nonzero(x._array))
[docs]
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
Parameters
----------
condition : Array
When ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Must be
compatible with ``x1`` and ``x2``.
x1 : Array
First input array. Must be compatible with ``condition`` and ``x2``.
x2 : Array
Second input array. Must be compatible with ``condition`` and ``x1``.
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(ht.where(condition._array, x1._array, x2._array))