Source code for heat.core.sanitation

"""
Collection of validation/sanitation routines.
"""

from __future__ import annotations

import numpy as np
import torch
import warnings
from typing import Any, Union, Sequence, List, Tuple

from .communication import MPI, Communication
from .dndarray import DNDarray

from . import factories
from . import stride_tricks
from . import types


# Names exported when this module is star-imported.
# NOTE(review): ``sanitize_in_nd_realfloating`` is defined in this module but
# is not listed here -- confirm whether leaving it out is intentional.
__all__ = [
    "sanitize_distribution",
    "sanitize_in",
    "sanitize_infinity",
    "sanitize_in_tensor",
    "sanitize_lshape",
    "sanitize_out",
    "sanitize_sequence",
    "scalar_to_1d",
    "sanitize_in_min_max_nd",
]


def sanitize_distribution(
    *args: DNDarray, target: DNDarray, diff_map: torch.Tensor = None
) -> Union[DNDarray, Tuple[DNDarray, ...]]:
    """
    Distribute every `arg` according to `target.lshape_map` or, if provided, `diff_map`.
    After this sanitation, the lshapes are compatible along the split dimension.
    `Args` can contain non-distributed DNDarrays, they will be split afterwards, if `target` is split.

    Parameters
    ----------
    args : DNDarray
        Dndarrays to be distributed
    target : DNDarray
        Dndarray used to sanitize the metadata and to, if diff_map is not given, determine the resulting distribution.
    diff_map : torch.Tensor (optional)
        Different lshape_map. Overwrites the distribution of the target array.
        Used in cases when the target array does not correspond to the actually wanted distribution,
        e.g. because it only contains a single element along the split axis and gets broadcast.

    Returns
    -------
    Union[DNDarray, Tuple[DNDarray, ...]]
        The single sanitized array if exactly one `arg` was passed, otherwise a tuple with one
        entry per `arg` (an empty tuple if no `arg` was passed).

    Raises
    ------
    TypeError
        When an argument is not a ``DNDarray`` or ``diff_map`` is not a ``torch.Tensor``.
    ValueError
        When the sizes along the split-axis do not match.
    NotImplementedError
        When the communicators or the split-axes do not match.

    See Also
    --------
    :func:`~heat.core.dndarray.create_lshape_map`
        Function to create the lshape_map.
    """
    out = []
    sanitize_in(target)

    # early out on 1 process: nothing to redistribute, only validate the types
    if target.comm.size == 1:
        for arg in args:
            sanitize_in(arg)
            out.append(arg)
        # same return convention as the distributed path below (also for zero args)
        return out[0] if len(out) == 1 else tuple(out)

    target_split = target.split
    if diff_map is not None:
        sanitize_in_tensor(diff_map)
        target_map = diff_map
        if target_split is not None:
            tmap_split = target_map[:, target_split]
            target_size = tmap_split.sum().item()
            # Check if the diff_map is balanced: build the map in which the first
            # `remainder` rows hold one element more than the rest and compare
            w_size = target_map.shape[0]
            tmap_balanced = torch.full_like(tmap_split, fill_value=target_size // w_size)
            remainder = target_size % w_size
            tmap_balanced[:remainder] += 1
            target_balanced = torch.equal(tmap_balanced, tmap_split)
    elif target_split is not None:
        target_map = target.lshape_map
        target_size = target.shape[target_split]
        target_balanced = target.is_balanced(force_check=False)

    for arg in args:
        sanitize_in(arg)
        if target.comm != arg.comm:
            # Build the detailed message on a best-effort basis; the raise happens
            # outside of the ``try`` so that it is not swallowed and replaced.
            try:
                message = (
                    f"Not implemented for other comms, found {target.comm.name} and {arg.comm.name}"
                )
            except Exception:
                message = "Not implemented for other comms"
            raise NotImplementedError(message)
        elif target_split is None:
            if arg.split is not None:
                raise NotImplementedError(
                    f"DNDarrays must have the same split axes, found {target_split} and {arg.split}"
                )
            else:
                out.append(arg)
        elif arg.shape[target_split] == 1 and target_size > 1:
            # broadcasting in split-dimension
            out.append(arg.resplit(None))
        elif arg.shape[target_split] != target_size:
            raise ValueError(
                f"Cannot distribute to match in split dimension, shapes are {target.shape} and {arg.shape}"
            )
        elif arg.split is None:
            # undistributed case
            if target_balanced:
                out.append(
                    factories.array(
                        arg, split=target_split, copy=False, comm=arg.comm, device=arg.device
                    )
                )
            else:
                # slice out the chunk of the local data that the target map assigns to this rank
                idx = [slice(None)] * arg.ndim
                idx[target_split] = slice(
                    target_map[: arg.comm.rank, target_split].sum(),
                    target_map[: arg.comm.rank + 1, target_split].sum(),
                )
                out.append(
                    factories.array(
                        arg.larray[tuple(idx)],
                        is_split=target_split,
                        copy=False,
                        comm=arg.comm,
                        device=arg.device,
                    )
                )
        elif arg.split != target_split:
            raise NotImplementedError(
                f"DNDarrays must have the same split axes, found {target_split} and {arg.split}"
            )
        elif not (target_balanced and arg.is_balanced(force_check=False)):
            # Split axes are the same and at least one is not balanced
            current_map = arg.lshape_map
            out_map = current_map.clone()
            out_map[:, target_split] = target_map[:, target_split]
            if not (current_map[:, target_split] == target_map[:, target_split]).all():
                out.append(arg.redistribute(lshape_map=current_map, target_map=out_map))
            else:
                out.append(arg)
        else:
            # both are balanced
            out.append(arg)

    if len(out) == 1:
        return out[0]
    return tuple(out)
def sanitize_in(x: Any):
    """
    Check that the given object is a ``DNDarray``.

    Parameters
    ----------
    x : Any
        Object to check.

    Raises
    ------
    TypeError
        If ``x`` is not an instance of ``DNDarray``.
    """
    if isinstance(x, DNDarray):
        return
    raise TypeError(f"Input must be a DNDarray, is {type(x)}")
def sanitize_in_min_max_nd(
    input: Any, min_nd: Union[int, None] = None, max_nd: Union[int, None] = None
) -> None:
    """
    Check that ``input`` is a ``DNDarray`` whose number of dimensions lies within the given bounds.

    Parameters
    ----------
    input : Any
        Object to check.
    min_nd : int or None
        Smallest allowed number of dimensions; ``None`` disables the lower bound.
    max_nd : int or None
        Largest allowed number of dimensions; ``None`` disables the upper bound.

    Raises
    ------
    TypeError
        If ``input`` is not a ``DNDarray``.
    ValueError
        If ``input`` has fewer than ``min_nd`` or more than ``max_nd`` dimensions.
    """
    sanitize_in(input)

    # the lower bound is checked first, then the upper bound
    below_minimum = min_nd is not None and input.ndim < min_nd
    if below_minimum:
        raise ValueError(f"Input needs to be a DNDarray with at least {min_nd} dimensions.")

    above_maximum = max_nd is not None and input.ndim > max_nd
    if above_maximum:
        raise ValueError(f"Input needs to be a DNDarray with at most {max_nd} dimensions.")
def sanitize_in_nd_realfloating(input: Any, inputname: str, allowed_ns: List[int]) -> None:
    """
    Check that ``input`` is a real floating point ``DNDarray`` with an allowed number of dimensions.

    Parameters
    ----------
    input : Any
        Object to check.
    inputname : str
        Name of the argument, only used to build the error messages.
    allowed_ns : List[int]
        Permitted numbers of dimensions.

    Raises
    ------
    TypeError
        If ``input`` is not a ``DNDarray`` or its data type is not accepted by
        ``types.heat_type_is_realfloating``.
    ValueError
        If ``input.ndim`` is not contained in ``allowed_ns``.
    """
    is_dndarray = isinstance(input, DNDarray)
    if not is_dndarray:
        raise TypeError(f"Argument {inputname} needs to be a DNDarray but is {type(input)}.")

    ndim_allowed = input.ndim in allowed_ns
    if not ndim_allowed:
        raise ValueError(
            f"Argument {inputname} needs to be a {allowed_ns}-dimensional, but is {input.ndim}-dimensional."
        )

    if types.heat_type_is_realfloating(input.dtype):
        return
    raise TypeError(
        f"Argument {inputname} needs to be a DNDarray with datatype float32 or float64, but data type is {input.dtype}."
    )
def sanitize_infinity(x: Union[DNDarray, torch.Tensor]) -> Union[int, float]:
    """
    Return the largest value representable by the data type of the input.

    Parameters
    ----------
    x : Union[DNDarray, torch.Tensor]
        Input object. For a ``torch.Tensor`` its own ``dtype`` is used, otherwise the
        ``dtype`` of ``x.larray``.

    Returns
    -------
    Union[int, float]
        ``torch.finfo(dtype).max``, or ``torch.iinfo(dtype).max`` if ``torch.finfo``
        rejects the data type with a ``TypeError``.
    """
    if isinstance(x, torch.Tensor):
        torch_dtype = x.dtype
    else:
        torch_dtype = x.larray.dtype

    # floating point limits first; fall back to integer limits on TypeError
    try:
        return torch.finfo(torch_dtype).max
    except TypeError:
        return torch.iinfo(torch_dtype).max
def sanitize_in_tensor(x: Any):
    """
    Check that the given object is a ``torch.Tensor``.

    Parameters
    ----------
    x : Any
        Object to check.

    Raises
    ------
    TypeError
        If ``x`` is not an instance of ``torch.Tensor``.
    """
    if isinstance(x, torch.Tensor):
        return
    raise TypeError(f"Input must be a torch.Tensor, is {type(x)}")
def sanitize_lshape(array: DNDarray, tensor: torch.Tensor):
    """
    Verify shape consistency when manipulating process-local arrays.

    Parameters
    ----------
    array : DNDarray
        the original, potentially distributed ``DNDarray``
    tensor : torch.Tensor
        process-local data meant to replace ``array.larray``

    Raises
    ------
    ValueError
        if shape of local ``torch.Tensor`` is inconsistent with global ``DNDarray``.
    """
    # input sanitation is done in parent function
    tshape = tuple(tensor.shape)
    if tshape == array.lshape:
        return

    gshape = array.gshape
    split = array.split
    if split is None:
        # allow for axes with size 0, other axes must match.
        # A non-empty axis that does not exist in the global shape counts as a
        # mismatch, so it is reported as ValueError instead of an IndexError.
        global_ndim = len(gshape)
        consistent = all(
            axis < global_ndim and size == gshape[axis]
            for axis, size in enumerate(tshape)
            if size != 0
        )
        if consistent:
            return
        raise ValueError(
            f"Shape of local tensor is inconsistent with global DNDarray: tensor.shape is {tshape}, should be {gshape}"
        )

    # size of non-split dimensions must match global shape
    reduced_gshape = gshape[:split] + gshape[split + 1 :]
    reduced_tshape = tshape[:split] + tshape[split + 1 :]
    if reduced_tshape == reduced_gshape:
        return
    raise ValueError(
        f"Shape of local tensor along non-split axes is inconsistent with global DNDarray: tensor.shape is {tshape}, DNDarray is {gshape}"
    )
def sanitize_out(
    out: DNDarray,
    output_shape: Tuple,
    output_split: int,
    output_device: str,
    output_comm: Communication = None,
):
    """
    Validate output buffer ``out``.

    Parameters
    ----------
    out : DNDarray
        the `out` buffer where the result of some operation will be stored
    output_shape : Tuple
        the calculated shape returned by the operation
    output_split : Int
        the calculated split axis returned by the operation
    output_device : Str
        "cpu" or "gpu" as per location of data
    output_comm : Communication
        Communication object of the result of the operation

    Raises
    ------
    TypeError
        if ``out`` is not a ``DNDarray``.
    ValueError
        if shape, split direction, or device of the output buffer ``out`` do not match the operation result.
    NotImplementedError
        if ``output_comm`` is given and differs from the communicator of ``out``.
    """
    if not isinstance(out, DNDarray):
        raise TypeError(f"expected `out` to be None or a DNDarray, but was {type(out)}")

    control_values = [output_shape, output_split, output_device]
    out_values = [out.shape, out.split, out.device]
    # bulk-check if the two lists differ
    if control_values != out_values:
        # if so, check each value individually to report the offending property
        if output_shape != out.shape:
            raise ValueError(f"Expecting output buffer of shape {output_shape}, got {out.shape}")
        if output_split != out.split:
            raise ValueError(f"Expecting output buffer of split {output_split}, got {out.split}")
        if output_device != out.device:
            raise ValueError(f"Expecting output buffer on device {output_device}, got {out.device}")

    # check that communicators match
    if output_comm is not None and out.comm != output_comm:
        # Build the detailed message on a best-effort basis; the raise happens
        # outside of the ``try`` so that it is not swallowed and replaced.
        try:
            message = (
                f"Not implemented for other comms, found {out.comm.name} and {output_comm.name}"
            )
        except Exception:
            message = "Not implemented for other comms"
        raise NotImplementedError(message)
def sanitize_sequence(
    seq: Union[Sequence[int, ...], Sequence[float, ...], DNDarray, torch.Tensor],
) -> List:
    """
    Validate a sequence and hand it back as a list.

    Only ``list`` and ``tuple`` objects are accepted by this implementation: a list is
    returned as is (the same object), a tuple is converted into a new list.

    Parameters
    ----------
    seq : Union[Sequence[int, ...], Sequence[float, ...], DNDarray, torch.Tensor]
        Input sequence.

    Returns
    -------
    List
        The list itself, or a list holding the elements of the tuple.

    Raises
    ------
    TypeError
        if ``seq`` is neither a list nor a tuple
    """
    if isinstance(seq, tuple):
        return list(seq)
    if isinstance(seq, list):
        return seq
    raise TypeError(f"seq must be a list or a tuple, got {type(seq)}")
def scalar_to_1d(x: DNDarray) -> DNDarray:
    """
    Convert a zero-dimensional ``DNDarray`` into a one-dimensional ``DNDarray`` holding one element.

    An input that is already one-dimensional with exactly one global element is
    returned unchanged.

    Parameters
    ----------
    x : DNDarray
        with `x.ndim = 0`

    Raises
    ------
    ValueError
        If ``x`` is neither zero-dimensional nor a one-dimensional array with one element.
    """
    if x.ndim == 0:
        # prepend an axis to the local tensor and wrap it as a non-split array
        return DNDarray(
            x.larray.unsqueeze(0),
            gshape=(1,),
            dtype=x.dtype,
            split=None,
            comm=x.comm,
            device=x.device,
            balanced=True,
        )

    already_1d_single = x.ndim == 1 and x.gnumel == 1
    if already_1d_single:
        return x

    raise ValueError(
        f"Input needs to be a scalar DNDarray,but was found to be {x.ndim}d DNDarray"
    )