heat.stride_tricks

A collection of functions for inferring or sanitizing shapes, axes, and slices before the main computation.

Module Contents

broadcast_shape(shape_a: Tuple[int, Ellipsis], shape_b: Tuple[int, Ellipsis]) → Tuple[int, Ellipsis]

Infers, if possible, the broadcast output shape of two operands a and b. Inspired by this Stack Overflow post: https://stackoverflow.com/questions/24743753/test-if-an-array-is-broadcastable-to-a-shape

Parameters:
  • shape_a (Tuple[int,...]) – Shape of first operand

  • shape_b (Tuple[int,...]) – Shape of second operand

Raises:

ValueError – If the two shapes cannot be broadcast.

Examples

>>> import heat as ht
>>> ht.core.stride_tricks.broadcast_shape((5,4),(4,))
(5, 4)
>>> ht.core.stride_tricks.broadcast_shape((1,100,1),(10,1,5))
(10, 100, 5)
>>> ht.core.stride_tricks.broadcast_shape((8,1,6,1),(7,1,5,))
(8, 7, 6, 5)
>>> ht.core.stride_tricks.broadcast_shape((2,1),(8,4,3))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "heat/core/stride_tricks.py", line 42, in broadcast_shape
    "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b)
ValueError: operands could not be broadcast, input shapes (2, 1) (8, 4, 3)
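
The broadcasting rule behind these examples can be sketched in plain Python. This is a minimal illustration, not Heat's implementation: the name broadcast_shape_sketch is hypothetical, and the sketch assumes NumPy-style semantics (shapes are aligned at their trailing dimension, and two dimensions are compatible when they are equal or one of them is 1).

```python
from itertools import zip_longest


def broadcast_shape_sketch(shape_a, shape_b):
    """Hypothetical helper illustrating the rule; not part of Heat."""
    result = []
    # Walk both shapes from the trailing dimension. A missing leading
    # dimension in the shorter shape is treated as size 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            raise ValueError(
                "operands could not be broadcast, input shapes {} {}".format(
                    shape_a, shape_b
                )
            )
        # A dimension of size 1 stretches to match the other operand.
        result.append(dim_b if dim_a == 1 else dim_a)
    return tuple(reversed(result))


print(broadcast_shape_sketch((8, 1, 6, 1), (7, 1, 5)))
```

In the failing example above, the trailing dimensions 1 and 3 are compatible, but the next pair, 2 and 4, differs and neither is 1, so the shapes cannot be broadcast.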

broadcast_shapes(*shapes: Tuple[int, Ellipsis]) → Tuple[int, Ellipsis]

Infers, if possible, the broadcast output shape of multiple operands.

Parameters:

*shapes (Tuple[int,...]) – Shapes of operands.

Returns:

The broadcast output shape.

Return type:

Tuple[int, …]

Raises:

ValueError – If the shapes cannot be broadcast.

Examples

>>> import heat as ht
>>> ht.broadcast_shapes((5,4),(4,))
(5, 4)
>>> ht.broadcast_shapes((1,100,1),(10,1,5))
(10, 100, 5)
>>> ht.broadcast_shapes((8,1,6,1),(7,1,5,))
(8, 7, 6, 5)
>>> ht.broadcast_shapes((2,1),(8,4,3))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "heat/core/stride_tricks.py", line 100, in broadcast_shapes
    "operands could not be broadcast, input shapes {}".format(shapes))
ValueError: operands could not be broadcast, input shapes ((2, 1), (8, 4, 3))
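
Extending the rule to any number of operands can be sketched as follows. This is an illustration under assumptions, not Heat's code: broadcast_shapes_sketch is a hypothetical name, and the sketch assumes NumPy-style broadcasting, where every shape is left-padded with 1s to a common length and each dimension may hold at most one distinct size other than 1.

```python
def broadcast_shapes_sketch(*shapes):
    """Hypothetical helper illustrating multi-operand broadcasting."""
    ndim = max((len(shape) for shape in shapes), default=0)
    # Left-pad every shape with 1s so that all shapes have ndim entries.
    padded = [(1,) * (ndim - len(shape)) + tuple(shape) for shape in shapes]
    result = []
    for dims in zip(*padded):
        # Collect the distinct sizes other than 1 along this dimension.
        non_unit = {dim for dim in dims if dim != 1}
        if len(non_unit) > 1:
            raise ValueError(
                "operands could not be broadcast, input shapes {}".format(shapes)
            )
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(result)


print(broadcast_shapes_sketch((3, 1), (1, 4), (2, 1, 1)))
```

Checking all operands per dimension at once gives the same result as folding a two-operand rule over the shapes, but it lets the error message report every input shape.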

sanitize_axis(shape: Tuple[int, Ellipsis], axis: int | None | Tuple[int, Ellipsis]) → int | None | Tuple[int, Ellipsis]

Checks conformity of an axis with respect to a given shape. The axis is converted to its non-negative equivalent and checked to be within bounds.

Parameters:
  • shape (Tuple[int, ...]) – Shape of an array

  • axis (int or Tuple[int, ...] or None) – The axis to be sanitized

Raises:
  • ValueError – If the axis cannot be sanitized, i.e. it is out of bounds.

  • TypeError – If the axis is not None, an int, or a tuple of ints.

Examples

>>> import heat as ht
>>> ht.core.stride_tricks.sanitize_axis((5,4,4),1)
1
>>> ht.core.stride_tricks.sanitize_axis((5,4,4),-1)
2
>>> ht.core.stride_tricks.sanitize_axis((5, 4), (1,))
(1,)
>>> ht.core.stride_tricks.sanitize_axis((5, 4), 1.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "heat/heat/core/stride_tricks.py", line 99, in sanitize_axis
    raise TypeError("axis must be None or int or tuple, but was {}".format(type(axis)))
TypeError: axis must be None or int or tuple, but was <class 'float'>
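
The normalization described above can be sketched in plain Python. This is a minimal sketch and not Heat's implementation: sanitize_axis_sketch is a hypothetical name, the ValueError message is invented for illustration, and passing None through unchanged is an assumption based on the return type in the signature.

```python
def sanitize_axis_sketch(shape, axis):
    """Hypothetical helper: map negative axes to non-negative ones."""
    ndim = len(shape)

    def normalize(single_axis):
        if not isinstance(single_axis, int):
            raise TypeError(
                "axis must be None or int or tuple, but was {}".format(
                    type(single_axis)
                )
            )
        # Valid axes for ndim dimensions lie in [-ndim, ndim).
        if not -ndim <= single_axis < ndim:
            raise ValueError(
                "axis {} is out of bounds for {} dimensions".format(
                    single_axis, ndim
                )
            )
        # The modulo turns -1 into ndim - 1, -2 into ndim - 2, and so on.
        return single_axis % ndim

    if axis is None:
        return None  # assumption: None means "no axis" and passes through
    if isinstance(axis, tuple):
        return tuple(normalize(entry) for entry in axis)
    return normalize(axis)


print(sanitize_axis_sketch((5, 4, 4), -1))
```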

sanitize_shape(shape: int | Tuple[int, Ellipsis], lval: int = 0) → Tuple[int, Ellipsis]

Verifies and normalizes the given shape.

Parameters:
  • shape (int or Tuple[int,...]) – Shape of an array.

  • lval (int) – Lowest legal value

Raises:
  • ValueError – If the shape contains illegal values, e.g. negative numbers.

  • TypeError – If the given shape is neither an int nor a sequence of ints.

Examples

>>> import heat as ht
>>> ht.core.stride_tricks.sanitize_shape(3)
(3,)
>>> ht.core.stride_tricks.sanitize_shape([1, 2, 3])
(1, 2, 3)
>>> ht.core.stride_tricks.sanitize_shape(1.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "heat/heat/core/stride_tricks.py", line 159, in sanitize_shape
    raise TypeError("expected sequence object with length >= 0 or a single integer")
TypeError: expected sequence object with length >= 0 or a single integer
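
The verification and normalization steps can be sketched in plain Python. This is an illustration under assumptions rather than Heat's implementation: sanitize_shape_sketch is a hypothetical name, the ValueError message is invented, and treating lval as an inclusive lower bound for every dimension is an assumption drawn from the parameter description.

```python
def sanitize_shape_sketch(shape, lval=0):
    """Hypothetical helper: normalize a shape to a tuple of ints."""
    message = "expected sequence object with length >= 0 or a single integer"
    if isinstance(shape, int):
        # A single integer describes a one-dimensional shape.
        shape = (shape,)
    else:
        try:
            shape = tuple(shape)
        except TypeError:
            # Non-iterable input such as 1.0 ends up here.
            raise TypeError(message) from None
    for dim in shape:
        if not isinstance(dim, int):
            raise TypeError(message)
        if dim < lval:
            raise ValueError(
                "dimension {} is below the lowest legal value {}".format(dim, lval)
            )
    return shape


print(sanitize_shape_sketch([1, 2, 3]))
```

With lval=1 the same check would also reject empty dimensions, for example sanitize_shape_sketch((0, 3), lval=1).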

sanitize_slice(sl: slice, max_dim: int) → slice

Removes None values from a slice.

Parameters:
  • sl (slice) – slice to adjust

  • max_dim (int) – maximum index for the given slice

Raises:

TypeError – If sl is not a slice.
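
One way to replace the None entries of a slice with explicit values is Python's built-in slice.indices method. The sketch below is not Heat's implementation: sanitize_slice_sketch is a hypothetical name, the error message is invented, and the sketch assumes a positive step and that max_dim is the length of the dimension being sliced.

```python
def sanitize_slice_sketch(sl, max_dim):
    """Hypothetical helper: fill in None entries of a slice."""
    if not isinstance(sl, slice):
        raise TypeError("expected a slice, but got {}".format(type(sl)))
    # slice.indices resolves None and negative entries against a
    # sequence of length max_dim and returns (start, stop, step).
    start, stop, step = sl.indices(max_dim)
    return slice(start, stop, step)


print(sanitize_slice_sketch(slice(2, None, 2), 10))
```

For negative steps, slice.indices can return a stop of -1, which means something different when fed back into a slice, so this sketch should only be read as covering positive steps.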