heat.sanitation
Collection of validation/sanitation routines.
Module Contents
- sanitize_distribution(*args: heat.core.dndarray.DNDarray, target: heat.core.dndarray.DNDarray, diff_map: torch.Tensor = None) → heat.core.dndarray.DNDarray | Tuple(DNDarray)
Distribute every arg according to target.lshape_map or, if provided, diff_map. After this sanitation, the lshapes are compatible along the split dimension. Args can contain non-distributed DNDarrays; these will be split afterwards if target is split.
- Parameters:
args (DNDarray) – DNDarrays to be distributed.
target (DNDarray) – DNDarray used to sanitize the metadata and, if diff_map is not given, to determine the resulting distribution.
diff_map (torch.Tensor (optional)) – Different lshape_map. Overwrites the distribution of the target array. Used in cases when the target array does not correspond to the actually wanted distribution, e.g. because it only contains a single element along the split axis and gets broadcast.
- Raises:
TypeError – When an argument is not a DNDarray or None.
ValueError – When the split-axes or sizes along the split-axis do not match.
See also
create_lshape_map() – Function to create the lshape_map.
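To make the idea of an lshape_map and of "compatible along the split dimension" concrete, here is a minimal pure-Python sketch. It does not call heat: `block_lshape_map` and `compatible_along_split` are hypothetical helpers, and the balanced block layout (leading ranks take the remainder) is an assumption made for illustration only.

```python
from typing import List, Tuple


def block_lshape_map(gshape: Tuple[int, ...], split: int, nprocs: int) -> List[Tuple[int, ...]]:
    """Hypothetical helper: per-process local shapes for a balanced block split."""
    base, rem = divmod(gshape[split], nprocs)
    lshapes = []
    for rank in range(nprocs):
        local = list(gshape)
        # Assumption: the first `rem` ranks each hold one extra element.
        local[split] = base + (1 if rank < rem else 0)
        lshapes.append(tuple(local))
    return lshapes


def compatible_along_split(map_a, map_b, split: int) -> bool:
    """True if both maps assign the same number of elements per rank along `split`."""
    return [s[split] for s in map_a] == [s[split] for s in map_b]


# Local shapes of a (10, 4) array split along axis 0 over 3 processes.
target_map = block_lshape_map((10, 4), split=0, nprocs=3)
# A differently distributed array: all rows live on rank 0.
skewed_map = [(10, 4), (0, 4), (0, 4)]

print(target_map)
print(compatible_along_split(target_map, skewed_map, 0))
print(compatible_along_split(target_map, block_lshape_map((10, 1), 0, 3), 0))
```

In this picture, an array laid out like `skewed_map` is the kind of input that would have to be redistributed before its local chunks line up with those of the target.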
- sanitize_in(x: Any)
Verify that input object is DNDarray.
- Parameters:
x (Any) – Input object
- Raises:
TypeError – When x is not a DNDarray.
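The type guard described above can be sketched without heat installed. `FakeDNDarray` and `check_is_array` are hypothetical stand-ins, not part of the library; the same pattern would apply to a guard for torch.Tensor.

```python
class FakeDNDarray:
    """Stand-in for heat's DNDarray so the sketch runs without heat installed."""


def check_is_array(x):
    """Raise TypeError unless x is an instance of the stand-in array class."""
    if not isinstance(x, FakeDNDarray):
        raise TypeError(f"input must be a DNDarray, got {type(x).__name__}")


check_is_array(FakeDNDarray())  # passes silently

try:
    check_is_array([1, 2, 3])
    rejected = False
except TypeError as err:
    rejected = True
    print(err)
```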
- sanitize_infinity(x: heat.core.dndarray.DNDarray | torch.Tensor) → int | float
Returns largest possible value for the dtype of the input array.
- Parameters:
x (Union[DNDarray, torch.Tensor]) – Input object.
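As a rough illustration of "largest possible value for the dtype", the sketch below uses only the Python standard library. `largest_value` and its string dtype names are hypothetical; with PyTorch available, `torch.iinfo(dtype).max` and `torch.finfo(dtype).max` give the corresponding limits for real dtypes.

```python
import sys


def largest_value(dtype_name: str):
    """Hypothetical stand-in: largest finite value for a few named dtypes."""
    if dtype_name in ("int8", "int16", "int32", "int64"):
        bits = int(dtype_name[3:])
        # Maximum of a signed two's-complement integer with `bits` bits.
        return 2 ** (bits - 1) - 1
    if dtype_name == "float64":
        # Largest finite double-precision value.
        return sys.float_info.max
    raise TypeError(f"unsupported dtype name: {dtype_name}")


print(largest_value("int8"))
print(largest_value("int32"))
print(largest_value("float64"))
```

The return type is an int for integer dtypes and a float for floating-point dtypes, which matches the `int | float` annotation in the signature.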
- sanitize_in_tensor(x: Any)
Verify that input object is torch.Tensor.
- Parameters:
x (Any) – Input object.
- Raises:
TypeError – When x is not a torch.Tensor.
- sanitize_lshape(array: heat.core.dndarray.DNDarray, tensor: torch.Tensor)
Verify shape consistency when manipulating process-local arrays.
- Parameters:
array (DNDarray) – the original, potentially distributed DNDarray
tensor (torch.Tensor) – process-local data meant to replace array.larray
- Raises:
ValueError – if shape of local torch.Tensor is inconsistent with global DNDarray.
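One plausible reading of "shape consistency" is sketched below in plain Python. The rule used here is an assumption, not taken from heat's source: the local shape must equal the global shape in every dimension except the split axis, and must not exceed the global size along the split axis. `check_local_shape` is a hypothetical helper.

```python
from typing import Optional, Tuple


def check_local_shape(gshape: Tuple[int, ...], split: Optional[int], local_shape: Tuple[int, ...]) -> None:
    """Raise ValueError if a process-local shape cannot belong to the global shape."""
    if len(local_shape) != len(gshape):
        raise ValueError(f"expected {len(gshape)} dimensions, got {len(local_shape)}")
    for dim, (g, l) in enumerate(zip(gshape, local_shape)):
        if dim == split:
            # Along the split axis a process holds at most the global extent.
            if l > g:
                raise ValueError(f"local size {l} exceeds global size {g} along split axis {dim}")
        elif l != g:
            # Non-split dimensions must match the global shape exactly.
            raise ValueError(f"dimension {dim}: local size {l} != global size {g}")


check_local_shape((10, 4), 0, (3, 4))  # consistent: 3 of 10 rows, all 4 columns

try:
    check_local_shape((10, 4), 0, (3, 5))
    inconsistent = False
except ValueError as err:
    inconsistent = True
    print(err)
```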
- sanitize_out(out: heat.core.dndarray.DNDarray, output_shape: Tuple, output_split: int, output_device: str, output_comm: heat.core.communication.Communication = None)
Validate output buffer out.
- Parameters:
out (DNDarray) – the out buffer where the result of some operation will be stored
output_shape (Tuple) – the calculated shape returned by the operation
output_split (int) – the calculated split axis returned by the operation
output_device (str) – “cpu” or “gpu” as per location of data
output_comm (Communication) – Communication object of the result of the operation
- Raises:
TypeError – if out is not a DNDarray.
ValueError – if shape, split direction, or device of the output buffer out do not match the operation result.
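The checks listed above (type, then shape, split and device) can be sketched with a small stand-in class. `FakeArray` and `check_out` are hypothetical and only mirror the documented error conditions; they do not reproduce heat's implementation and ignore the communicator argument.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeArray:
    """Stand-in carrying only the metadata the checks need."""
    shape: Tuple[int, ...]
    split: Optional[int]
    device: str


def check_out(out, output_shape, output_split, output_device) -> None:
    """Raise if `out` cannot hold a result with the given metadata."""
    if not isinstance(out, FakeArray):
        raise TypeError(f"out must be a DNDarray, got {type(out).__name__}")
    if out.shape != tuple(output_shape):
        raise ValueError(f"expected out.shape {tuple(output_shape)}, got {out.shape}")
    if out.split != output_split:
        raise ValueError(f"expected out.split {output_split}, got {out.split}")
    if out.device != output_device:
        raise ValueError(f"expected out.device {output_device}, got {out.device}")


buffer = FakeArray(shape=(10, 4), split=0, device="cpu")
check_out(buffer, (10, 4), 0, "cpu")  # matches, so nothing is raised

try:
    check_out(buffer, (10, 4), 1, "cpu")
    mismatch = False
except ValueError as err:
    mismatch = True
    print(err)
```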
- sanitize_sequence(seq: Sequence[int, Ellipsis] | Sequence[float, Ellipsis] | heat.core.dndarray.DNDarray | torch.Tensor) → List
Check if sequence is valid, return list.
- Parameters:
seq (Union[Sequence[int, ...], Sequence[float, ...], DNDarray, torch.Tensor]) – Input sequence.
- Raises:
TypeError – if seq is neither a list nor a tuple
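A minimal sketch of the list and tuple cases described above, in plain Python. `to_list` is a hypothetical helper; the DNDarray and torch.Tensor inputs that the signature also accepts are deliberately left out here.

```python
def to_list(seq):
    """Return `seq` as a list, accepting only lists and tuples in this sketch."""
    if isinstance(seq, list):
        return seq
    if isinstance(seq, tuple):
        return list(seq)
    raise TypeError(f"seq must be a list or a tuple, got {type(seq).__name__}")


print(to_list((1, 2, 3)))
print(to_list([0.5, 1.5]))

try:
    to_list("abc")
    bad_input = False
except TypeError as err:
    bad_input = True
    print(err)
```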
- scalar_to_1d(x: heat.core.dndarray.DNDarray) → heat.core.dndarray.DNDarray
Turn a scalar DNDarray into a 1-D DNDarray with 1 element.
- Parameters:
x (DNDarray) – with x.ndim = 0