Source code for heat.core.stride_tricks

"""
A collection of helper functions for inferring broadcast shapes and for validating or normalizing axes, shapes and slices before major computation
"""

import itertools
import numpy as np
import torch

from typing import Tuple, Union


def broadcast_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Infers, if possible, the broadcast output shape of two operands a and b. Inspired by stackoverflow post:
    https://stackoverflow.com/questions/24743753/test-if-an-array-is-broadcastable-to-a-shape

    Parameters
    ----------
    shape_a : Tuple[int,...]
        Shape of first operand
    shape_b : Tuple[int,...]
        Shape of second operand

    Returns
    -------
    Tuple[int, ...]
        The broadcast output shape, converted to a plain tuple.

    Raises
    ------
    ValueError
        If the two shapes cannot be broadcast (``torch.broadcast_shapes`` raised a ``RuntimeError``).
    TypeError
        If ``torch.broadcast_shapes`` rejects the operands with a ``TypeError`` or ``NameError``.

    Examples
    --------
    >>> import heat as ht
    >>> ht.core.stride_tricks.broadcast_shape((5, 4), (4,))
    (5, 4)
    >>> ht.core.stride_tricks.broadcast_shape((1, 100, 1), (10, 1, 5))
    (10, 100, 5)
    >>> ht.core.stride_tricks.broadcast_shape(
    ...     (8, 1, 6, 1),
    ...     (
    ...         7,
    ...         1,
    ...         5,
    ...     ),
    ... )
    (8, 7, 6, 5)
    >>> ht.core.stride_tricks.broadcast_shape((2, 1), (8, 4, 3))
    Traceback (most recent call last):
      ...
    ValueError: operands could not be broadcast, input shapes (2, 1) (8, 4, 3)
    """
    try:
        resulting_shape = torch.broadcast_shapes(shape_a, shape_b)
    except (TypeError, NameError) as err:
        # Either operand may be the offending one, so report both instead of
        # blaming the first operand only.
        raise TypeError(
            f"operands must be tuples of ints, not {shape_a} and {shape_b}"
        ) from err
    except RuntimeError as err:
        raise ValueError(
            f"operands could not be broadcast, input shapes {shape_a} {shape_b}"
        ) from err

    return tuple(resulting_shape)
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Infers, if possible, the broadcast output shape of multiple operands.

    Parameters
    ----------
    *shapes : Tuple[int,...]
        Shapes of operands.

    Returns
    -------
    Tuple[int, ...]
        The broadcast output shape, converted to a plain tuple.

    Raises
    ------
    ValueError
        If the shapes cannot be broadcast (``torch.broadcast_shapes`` raised a ``RuntimeError``).
    TypeError
        If ``torch.broadcast_shapes`` rejects the operands with a ``TypeError``.

    Examples
    --------
    >>> import heat as ht
    >>> ht.broadcast_shapes((5, 4), (4,))
    (5, 4)
    >>> ht.broadcast_shapes((1, 100, 1), (10, 1, 5))
    (10, 100, 5)
    >>> ht.broadcast_shapes(
    ...     (8, 1, 6, 1),
    ...     (
    ...         7,
    ...         1,
    ...         5,
    ...     ),
    ... )
    (8, 7, 6, 5)
    >>> ht.broadcast_shapes((2, 1), (8, 4, 3))
    Traceback (most recent call last):
      ...
    ValueError: operands could not be broadcast, input shapes ((2, 1), (8, 4, 3))
    """
    try:
        resulting_shape = torch.broadcast_shapes(*shapes)
    except TypeError as err:
        # Keep the torch error attached as the explicit cause for debugging.
        raise TypeError(f"operands must be tuples of ints, not {shapes}") from err
    except RuntimeError as err:
        raise ValueError(
            f"operands could not be broadcast, input shapes {shapes}"
        ) from err

    return tuple(resulting_shape)
def sanitize_axis(
    shape: Tuple[int, ...], axis: Union[int, None, Tuple[int, ...]]
) -> Union[int, None, Tuple[int, ...]]:
    """
    Validates an axis (or tuple of axes) against a shape and converts negative
    entries to their non-negative equivalents.

    Parameters
    ----------
    shape : Tuple[int, ...]
        Shape of an array
    axis : int or Tuple[int, ...] or None
        The axis to be sanitized

    Returns
    -------
    int or Tuple[int, ...] or None
        The non-negative axis (or tuple of axes). ``None`` is returned when ``axis`` is ``None`` or
        when ``shape`` is empty; in the latter case ``axis`` is not inspected at all.

    Raises
    ------
    ValueError
        if the axis cannot be sanitized, i.e. out of bounds.
    TypeError
        if the axis is neither ``None``, an ``int`` nor a ``tuple``. Elements of a tuple are not
        type-checked individually.

    Examples
    --------
    >>> import heat as ht
    >>> ht.core.stride_tricks.sanitize_axis((5, 4, 4), 1)
    1
    >>> ht.core.stride_tricks.sanitize_axis((5, 4, 4), -1)
    2
    >>> ht.core.stride_tricks.sanitize_axis((5, 4), (1,))
    (1,)
    >>> ht.core.stride_tricks.sanitize_axis((5, 4), 1.0)
    Traceback (most recent call last):
      ...
    TypeError: axis must be None or int or tuple, but was <class 'float'>
    """
    requested = axis  # kept untouched for the error messages
    ndim = len(shape)

    # scalars are handled like unsplit matrices
    if ndim == 0:
        axis = None

    if axis is None:
        return None

    if not isinstance(axis, (int, tuple)):
        raise TypeError(f"axis must be None or int or tuple, but was {type(axis)}")

    if isinstance(axis, tuple):
        wrapped = tuple(dim + ndim if dim < 0 else dim for dim in axis)
        if any(dim < 0 or dim >= ndim for dim in wrapped):
            raise ValueError(f"axis {requested} is out of bounds for {ndim}-dimensional array")
        return wrapped

    # single integer axis: wrap negatives once, then bounds-check
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {requested} is out of bounds for {ndim}-dimensional array")
    return axis
def sanitize_shape(shape: Union[int, Tuple[int, ...]], lval: int = 0) -> Tuple[int, ...]:
    """
    Verifies and normalizes the given shape.

    Parameters
    ----------
    shape : int or Tuple[int,...]
        Shape of an array. Any object exposing ``__iter__`` is treated as a sequence of
        dimensions; anything else is treated as a single dimension.
    lval : int
        Lowest legal value

    Returns
    -------
    Tuple[int, ...]
        The shape as a tuple. NumPy integer dimensions are converted to Python ``int``.

    Raises
    ------
    ValueError
        If the shape contains illegal values, i.e. a dimension smaller than ``lval``
        (e.g. negative numbers for the default ``lval=0``).
    TypeError
        If the given shape is neither an int nor a sequence of ints.

    Examples
    --------
    >>> import heat as ht
    >>> ht.core.stride_tricks.sanitize_shape(3)
    (3,)
    >>> ht.core.stride_tricks.sanitize_shape([1, 2, 3])
    (1, 2, 3)
    >>> ht.core.stride_tricks.sanitize_shape(1.0)
    Traceback (most recent call last):
      ...
    TypeError: expected sequence object with length >= 0 or a single integer
    """
    dimensions = tuple(shape) if hasattr(shape, "__iter__") else (shape,)

    normalized = []
    for dimension in dimensions:
        if isinstance(dimension, np.integer):
            dimension = int(dimension)
        if not isinstance(dimension, int):
            raise TypeError("expected sequence object with length >= 0 or a single integer")
        if dimension < lval:
            raise ValueError("negative dimensions are not allowed")
        # Store the converted value so NumPy integers do not leak into the result.
        normalized.append(dimension)

    return tuple(normalized)
def sanitize_slice(sl: slice, max_dim: int) -> slice:
    """
    Replaces the ``None`` entries of a slice with explicit values and shifts negative
    start/stop values by ``max_dim``.

    Parameters
    ----------
    sl : slice
        slice to adjust
    max_dim : int
        value used for a missing ``stop`` and added to a negative ``start`` or ``stop``

    Returns
    -------
    slice
        A slice whose start, stop and step are all set. A missing start becomes ``0``, a missing
        stop becomes ``max_dim`` and a missing step becomes ``1``; these defaults do not depend on
        the sign of the step.

    Raises
    ------
    TypeError
        if sl is not a slice.
    """
    if not isinstance(sl, slice):
        raise TypeError("This function is only for slices!")

    start = sl.start if sl.start is not None else 0
    if start < 0:
        start += max_dim

    stop = sl.stop if sl.stop is not None else max_dim
    if stop < 0:
        stop += max_dim

    step = sl.step if sl.step is not None else 1

    return slice(start, stop, step)