# Source code for heat.array_api._manipulation_functions

from __future__ import annotations

from ._array_object import Array
from ._data_type_functions import result_type

from typing import Optional, Tuple, Union, List

import heat as ht


def concat(
    arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
) -> Array:
    """
    Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    arrays : Union[Tuple[Array, ...], List[Array]]
        Arrays to join. They must share the same shape, except along the
        dimension given by ``axis``.
    axis : Optional[int]
        Axis along which to join. When ``None``, every input is flattened
        first and the flattened arrays are joined along axis ``0``.
        Default: ``0``.

    Returns
    -------
    Array
        A new array wrapping the concatenated data.
    """
    # The return value is discarded: the call is made only so that
    # disallowed dtype combinations raise before any data is touched.
    result_type(*arrays)

    underlying = tuple(member._array for member in arrays)
    if axis is None:
        flattened = tuple(ht.flatten(item) for item in underlying)
        return Array._new(ht.concatenate(flattened, axis=0))
    return Array._new(ht.concatenate(underlying, axis=axis))
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
    """
    Insert a new axis (dimension) of size one at position ``axis``.

    Parameters
    ----------
    x : Array
        Input array.
    axis : int
        Zero-based position of the new axis. For an input of rank ``N``
        (number of dimensions), ``axis`` must lie in the closed interval
        ``[-N-1, N]``.

    Returns
    -------
    Array
        A new array wrapping the expanded data.

    Raises
    ------
    IndexError
        If ``axis`` lies outside ``[-N-1, N]``.
    """
    rank = x.ndim
    lowest_valid = -rank - 1
    highest_valid = rank
    if axis < lowest_valid or axis > highest_valid:
        raise IndexError("Invalid axis")

    expanded = ht.expand_dims(x._array, axis)
    return Array._new(expanded)
def flip(
    x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Array:
    """
    Reverse the order of elements along the given ``axis``; the shape of the
    array is preserved.

    Parameters
    ----------
    x : Array
        Input array.
    axis : Optional[Union[int, Tuple[int, ...]]]
        Axis (or axes) to flip along. ``None`` flips along every axis of the
        input array.

    Returns
    -------
    Array
        A new array wrapping the flipped data.
    """
    data = x._array
    flipped = ht.flip(data, axis=axis)
    return Array._new(flipped)
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
    """
    Permute the axes (dimensions) of ``x``.

    Parameters
    ----------
    x : Array
        Input array.
    axes : Tuple[int, ...]
        A permutation of ``(0, 1, ..., N-1)``, where ``N`` is the number of
        axes (dimensions) of ``x``.

    Returns
    -------
    Array
        A new array wrapping the transposed data.
    """
    data = x._array
    # The axes are handed to ``ht.transpose`` as a list rather than a tuple.
    order = list(axes)
    return Array._new(ht.transpose(data, order))
def reshape(
    x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> Array:
    """
    Reshape an array without changing its data.

    The input ``x`` is never modified: the reshaped data is always wrapped in
    a new ``Array``. (Earlier behavior rebound ``x._array`` and returned ``x``
    itself whenever ``copy`` was falsy, which silently changed the shape of
    the caller's input array.)

    Parameters
    ----------
    x : Array
        Input array to reshape.
    shape : Tuple[int, ...]
        A new shape compatible with the original shape. One dimension may be
        ``-1``, in which case it is inferred from the length of the array and
        the remaining dimensions.
    copy : Optional[bool]
        Boolean indicating whether or not to copy the input array. Accepted
        for Array API compatibility; this wrapper does not itself force or
        forbid a copy of the underlying data.
        NOTE(review): whether ``ht.reshape`` shares memory with its input is
        not visible from this module -- confirm before relying on
        ``copy=True`` / ``copy=False`` guarantees.

    Returns
    -------
    Array
        A new array wrapping the reshaped data.
    """
    reshaped = ht.reshape(x._array, shape)
    # Always hand back a fresh wrapper so the caller's ``x`` keeps its shape.
    return Array._new(reshaped)
def roll(
    x: Array,
    /,
    shift: Union[int, Tuple[int, ...]],
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Array:
    """
    Roll array elements along a specified axis.

    Elements pushed past the last position re-enter at the first position,
    and elements pushed past the first position re-enter at the last.

    Parameters
    ----------
    x : Array
        Input array.
    shift : Union[int, Tuple[int, ...]]
        Number of places to shift by. If ``shift`` is a tuple, ``axis`` must
        be a tuple of the same size and each axis is shifted by the matching
        entry of ``shift``. If ``shift`` is an ``int`` and ``axis`` is a
        tuple, the same ``shift`` is applied to every listed axis.
    axis : Optional[Union[int, Tuple[int, ...]]]
        Axis (or axes) along which to shift. If ``None``, the array is
        flattened, shifted, and then restored to its original shape.
        Default: ``None``.

    Returns
    -------
    Array
        A new array wrapping the rolled data.
    """
    data = x._array
    rolled = ht.roll(data, shift, axis=axis)
    return Array._new(rolled)
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
    """
    Remove singleton dimensions (axes) from ``x``.

    Parameters
    ----------
    x : Array
        Input array.
    axis : Union[int, Tuple[int, ...]]
        Axis (or axes) to squeeze.

    Returns
    -------
    Array
        A new array wrapping the squeezed data.

    Raises
    ------
    ``ValueError``, if an axis is selected with shape entry greater than one.
    """
    data = x._array
    squeezed = ht.squeeze(data, axis=axis)
    return Array._new(squeezed)
def stack(
    arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0
) -> Array:
    """
    Join a sequence of arrays along a new axis.

    Parameters
    ----------
    arrays : Union[Tuple[Array, ...], List[Array]]
        Arrays to join. Each array must have the same shape.
    axis : int
        Axis along which the arrays will be joined.

    Returns
    -------
    Array
        A new array wrapping the stacked data.
    """
    # The result is intentionally unused: this call exists only to raise on
    # disallowed type combinations.
    result_type(*arrays)

    underlying = tuple(member._array for member in arrays)
    stacked = ht.stack(underlying, axis=axis)
    return Array._new(stacked)