:mod:`heat.vmap` ===================== .. py:module:: heat.core.vmap .. autoapi-nested-parse:: Vmap module. This implements a functionality similar to PyTorchs vmap function. Requires PyTorch 2.0.0 or higher. Module Contents --------------- .. function:: vmap(func: Callable[[Tuple[torch.Tensor]], Tuple[torch.Tensor]], out_dims: Union[Tuple[int], int] = 0, randomness: str = 'error', *, chunk_size: int = None) -> Callable[[Tuple[heat.core.dndarray.DNDarray]], Tuple[heat.core.dndarray.DNDarray]] Apply a function to a DNDarray in a vectorized way. `heat.vmap` return a callable that can be applied to DNDarrays. Vectorization will automatically take place along the split axis/axes of the DNDarray(s); therefore, unlike in PyTorch, there is no argument `in_dims`. What we here refer to as "split axis/dimension" in the Heat terminology is often referred to as "batch axis/dimension" in the PyTorch terminology. :param func: The function to apply in a vmapped way to the DNDarray(s). It must take PyTorch tensor(s) as positional arguments. Additional parameters, not to be vmapped over, can be passed as keyword arguments. The callable returned by by `heat.vmap` will also accept these keyword arguments. :type func: callable :param out_dims: The dimensions of the output(s) that are mapped over; identical to the split dimension(s) of the output(s). Default is 0. :type out_dims: int or tuple of int, optional :param randomness: Determines how to handle randomness in the function to be vmapped. This argument is directly passed to the underlying PyTorch vmaps; see the corresponding PyTorch documentation for more information and the note below. If 'error' (default), an error is raised if the function to be mapped contains randomness. :type randomness: {'error', 'different', 'same'}, optional :param chunk_size: The size of the chunks to use for the process-local computation. If None (default), apply a single PyTorch vmap over the process-local chunks of data. If not None, then compute the process-local PyTorch vmap `chunk_size` many samples at a time. Note that `chunk_size=1` is equivalent to computing the process-local PyTorch vmap's with a for-loop. If you run into memory issues computing the vmap, please try a non-None chunk_size. :type chunk_size: int, optional .. note:: This function is a wrapper around PyTorch's `torch.vmap` function. In essence, a PyTorch vmap is applied to the input function `func` on each MPI process separately. This process-local PyTorch-vmapped function is then applied to the process-local chunks of the input DNDarray(s). Please note that the options 'same' and 'different' for `randomness` will result in behaviour different from the one known by PyTorch as (at least currently) no actions are taken to synchronize randomness across the MPI processes.