heat.nn.data_parallel

This module contains the general data parallel neural network classes.

Module Contents

class DataParallel(module: torch.nn.Module, comm: heat.core.communication.MPICommunication, optimizer: heat.optim.DataParallelOptimizer | List | Tuple, blocking_parameter_updates: bool = False)

Bases: torch.nn.Module

Implements data parallelism across multiple processes. This means that the same model is run locally on each process. Creating the model is similar to PyTorch; the only change is the use of HeAT layers (ht.nn.layer) when initializing the network and the optimizer. If a HeAT layer does not exist, it falls back to the PyTorch layer of the same name. The same is true for the optimizer. It is possible to use more than one optimizer, but parameter updates are then restricted to blocking communication. The same limitation applies when the given optimizer does not handle exactly the model's full set of parameters. Both __init__() and forward() must be defined in the class that defines the network.

An example of this is shown in examples/mnist.py.

It is highly recommended to use a HeAT DataLoader, see ht.utils.data.DataLoader. The default communication scheme is blocking. The blocking scheme averages the model parameters during the backward step and synchronizes them before the next model iteration.

Using more than one optimizer forces parameter updates to use blocking MPI communication.

Variables:
  • module (torch.nn.Module) – The local module

  • comm (MPICommunication) – Communicator to use

  • optimizer (heat.DataParallelOptimizer, List, Tuple) – Individual or sequence of DataParallelOptimizers to be used

  • blocking_parameter_updates (bool, optional) – Flag indicating whether blocking communication is used for parameter updates. Default: non-blocking updates (False)
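Examples

A minimal construction sketch following the signature above. The network, the torch optimizer, and the way the optimizer wrapper is configured are illustrative assumptions; see examples/mnist.py for the authoritative version.

    import torch
    import heat as ht

    class Net(torch.nn.Module):
        # both __init__() and forward() must be defined in the network class
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(32, 10)

        def forward(self, x):
            return self.fc(x)

    t_net = Net()
    t_optim = torch.optim.SGD(t_net.parameters(), lr=0.1)

    # wrap the torch optimizer so that HeAT can synchronize parameter updates;
    # the wrapper's own blocking flag (if any) should agree with the setting below
    dp_optim = ht.optim.DataParallelOptimizer(t_optim)

    # the same local model now runs on every MPI process
    dp_net = ht.nn.DataParallel(
        t_net,
        comm=ht.MPI_WORLD,                    # global MPI communicator
        optimizer=dp_optim,
        blocking_parameter_updates=False,     # documented default
    )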

__setattr__(name: str, value: torch.nn.Module | torch.Tensor | Any) None

Overwrite the current torch.nn.Module.__setattr__ so that it auto-detects the end of an epoch's training phase and finalizes the wait handles (only relevant for non-blocking updates).

forward(*inputs: tuple, **kwargs: dict) torch.Tensor

Do the forward step for the network; if non-blocking parameter updates are in use, receive the averaged parameters from the last backward step before running the local forward pass.

_iparam_update(param_slice: slice = None, layer_names: List[str] = None) None

Update parameters asynchronously via wait handles.

Parameters:
  • param_slice (slice, optional) – Slice object for creating a view onto the optimizer's params list. By default, the whole params list is used (None)

  • layer_names (list(str), optional) – List of layer names whose parameters will be updated; must match param_slice. By default, all layers are updated (None)

_blocking_hook(grad_loc: torch.Tensor) torch.Tensor

Add a blocking hook to the PyTorch DAG for all of the backwards calls.

Parameters:

grad_loc (torch.Tensor) – The local gradient

References

[1] (cf. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.register_hook).
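Conceptually, the blocking scheme attaches a gradient hook to every parameter that averages the local gradient across all processes before the optimizer step. The following is a standalone sketch of that pattern, using mpi4py directly instead of heat's MPICommunication; it illustrates the idea and is not heat's actual implementation.

    import torch
    from mpi4py import MPI

    def make_blocking_hook(comm: MPI.Comm):
        def hook(grad_loc: torch.Tensor) -> torch.Tensor:
            # pre-divide so that the summing allreduce yields the average
            grad_avg = grad_loc.detach() / comm.Get_size()
            # blocking allreduce on the contiguous CPU gradient buffer
            comm.Allreduce(MPI.IN_PLACE, grad_avg.numpy(), op=MPI.SUM)
            return grad_avg    # returned tensor replaces the local gradient
        return hook

    model = torch.nn.Linear(4, 2)
    for param in model.parameters():
        param.register_hook(make_blocking_hook(MPI.COMM_WORLD))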

_nonblocking_hook(layer_name: str, param_name: str) Callable

Add a nonblocking hook to send and receive the averaged parameters after the backward step.

Parameters:
  • layer_name (str) – Name of the layer

  • param_name (str) – Name of the parameter
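Conceptually, the nonblocking scheme starts an Iallreduce for the affected parameter after the backward step, stores the resulting wait handle under the layer and parameter name, and only completes it right before the parameter is needed again (see _forward_hook below). The following standalone sketch of the wait-handle pattern uses mpi4py directly rather than heat's communicator and does not reproduce heat's exact mechanics.

    import torch
    from mpi4py import MPI

    wait_handles = {}   # (layer_name, param_name) -> (MPI request, shared buffer)

    def start_param_average(comm, layer_name, param_name, param):
        # divide first so that the summing Iallreduce yields the average
        buf = (param.detach() / comm.Get_size()).contiguous().numpy()
        req = comm.Iallreduce(MPI.IN_PLACE, buf, op=MPI.SUM)
        wait_handles[(layer_name, param_name)] = (req, buf)

    def finalize_param_average(layer_name, param_name, param):
        req, buf = wait_handles.pop((layer_name, param_name))
        req.Wait()                               # complete the communication
        with torch.no_grad():
            param.copy_(torch.from_numpy(buf))   # install the averaged values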

_forward_hook(layer_name: str) Callable

Add a forward hook to update parameters during the forward step. This returns a hook which can be registered with the submodule's register_forward_pre_hook method.

Parameters:

layer_name (str) – Name of the layer
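The hook returned here is a closure over the layer name and is attached with register_forward_pre_hook, so that pending updates for that layer can be finalized before its forward pass runs. A minimal standalone sketch of the registration mechanics, in which the finalize callable is only a placeholder for the wait-handle logic sketched above:

    import torch

    def make_forward_pre_hook(layer_name: str, finalize):
        def pre_hook(module, inputs):
            # finalize stands in for waiting on the layer's pending wait handles
            # and writing the averaged parameters back into the module
            finalize(layer_name, module)
            return None    # leave the inputs unchanged
        return pre_hook

    layer = torch.nn.Linear(4, 2)
    layer.register_forward_pre_hook(
        make_forward_pre_hook("fc1", lambda name, mod: None)   # no-op placeholder
    )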

_reset_parameters(module: torch.nn.Module) None

Reset the parameters of the given torch submodule. Only works for basic module types that provide a reset_parameters function.

Parameters:

module (torch.nn.Module) – Submodule whose parameters are to be reset
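A sketch of the idea, applied recursively with torch.nn.Module.apply; submodules without a reset_parameters method are simply skipped:

    import torch

    def reset_submodule(module: torch.nn.Module) -> None:
        # only basic module types (e.g. Linear, Conv2d) define reset_parameters
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
    net.apply(reset_submodule)   # the ReLU has no reset_parameters and is skipped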

class DataParallelMultiGPU(module: torch.nn.Module, optimizer: heat.optim.DASO, comm: heat.core.communication.MPICommunication = MPI_WORLD)

Bases: torch.nn.Module

This creates data parallel networks local to each node using PyTorch's distributed classes. It does NOT perform any global synchronization. To make optimal use of this structure, use ht.optim.DASO.

Notes

The PyTorch distributed process group must already exist before this class is initialized.

Parameters:
  • module (torch.nn.Module) – an implemented PyTorch model

  • optimizer (optim.DASO) – A DASO optimizer; other optimizers are not yet supported. The DASO optimizer must be created before this class is initialized.

  • comm (MPICommunication, optional) – A global communicator. Default: MPI_WORLD
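Examples

A construction sketch following the signature above. The torch.distributed initialization, the base torch optimizer, and the keyword arguments passed to ht.optim.DASO are placeholders and must be adapted to the actual DASO interface and your launcher setup.

    import torch
    import torch.distributed as dist
    import heat as ht

    # the PyTorch distributed process group must already exist;
    # backend and initialization method depend on how the job is launched
    dist.init_process_group(backend="nccl")

    t_net = torch.nn.Sequential(torch.nn.Linear(32, 10)).cuda()
    t_optim = torch.optim.SGD(t_net.parameters(), lr=0.1)

    # keyword names below are illustrative; see ht.optim.DASO for the real ones
    daso = ht.optim.DASO(local_optimizer=t_optim, total_epochs=20)

    # node-local data parallel network; global synchronization is left to DASO
    dp_net = ht.nn.DataParallelMultiGPU(t_net, daso, comm=ht.MPI_WORLD)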

forward(*inputs: Tuple, **kwargs: Dict) torch.Tensor

Calls the forward method of the torch model.

_reset_parameters(module: torch.nn.Module) None

Reset the parameters of the given torch submodule. Only works for basic module types that provide a reset_parameters function.

Parameters:

module (torch.nn.Module) – Submodule whose parameters are to be reset