heat.nn.data_parallel
General data parallel neural network classes.
Module Contents
- class DataParallel(module: torch.nn.Module, comm: heat.core.communication.MPICommunication, optimizer: heat.optim.DataParallelOptimizer | List | Tuple, blocking_parameter_updates: bool = False)
Bases: torch.nn.Module
Implements data parallelism across multiple processes: the same model is run locally on each process. The model is created as in PyTorch; the only change is to use HeAT layers (ht.nn.layer) when initializing the network and optimizer. If no HeAT layer of a given name exists, the PyTorch layer of the same name is used instead. The same is true for the optimizer. It’s possible to use more than one optimizer, but communication during parameter updates is then limited to blocking. The same limitation takes effect when passing an optimizer that does not deal exactly with the set of the model’s parameters. For the given model, both the __init__() and forward() functions must be defined in the class defining the network.
An example of this is shown in examples/mnist.py.
It is highly recommended that a HeAT DataLoader is used, see ht.utils.data.DataLoader. The default communications scheme for this is blocking. The blocking scheme will average the model parameters during the backwards step, synchronizing them before the next model iteration.
Using more than one optimizer forces parameter updates to use blocking MPI communication.
- Variables:
module (torch.nn.Module) – The local module
comm (MPICommunication) – Communicator to use
optimizer (heat.DataParallelOptimizer, List, Tuple) – Individual or sequence of DataParallelOptimizers to be used
blocking_parameter_updates (bool, optional) – Flag indicating the usage of blocking communications for parameter updates. Default: False (non-blocking updates)
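The blocking scheme described above can be illustrated without MPI. The following is a minimal sketch in plain Python, not HeAT code: the "processes" are simulated as list entries, and `local_gradient`, `allreduce_mean`, and `train_step` are hypothetical helpers introduced only for this illustration. It averages gradients rather than parameters; because all replicas start from the same value and apply the same averaged update, they stay identical after every step.

```python
from typing import List, Tuple

Shard = Tuple[List[float], List[float]]


def local_gradient(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of the mean squared error of y = w * x on one rank's data shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values: List[float]) -> float:
    """Hypothetical stand-in for an MPI Allreduce followed by division by the rank count."""
    return sum(values) / len(values)


def train_step(weights: List[float], shards: List[Shard], lr: float) -> List[float]:
    """One blocking step: every replica computes a local gradient, then all apply the mean."""
    grads = [local_gradient(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    mean_grad = allreduce_mean(grads)  # every simulated rank waits for this value
    return [w - lr * mean_grad for w in weights]


# Two simulated ranks, each holding its own shard of data generated by y = 2 * x.
shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
weights = [0.0, 0.0]  # one replica of the single model parameter per rank
for _ in range(200):
    weights = train_step(weights, shards, lr=0.01)
```

After the loop both replicas hold the same weight, close to 2, even though each rank only saw its own shard.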
- module
- comm
- blocking_parameter_updates = False
- _dp_optimizers = []
- _layer_wait_handles
- _fwd_hook_handles = []
- _active_layers
- _param_slices
- _param_indices
- __setattr__(name: str, value: torch.nn.Module | torch.Tensor | Any) -> None
Overrides torch.nn.Module.__setattr__ so that the end of an epoch’s training phase is detected automatically and outstanding wait handles are finalized (only relevant for non-blocking parameter updates).
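The idea of hooking attribute assignment can be sketched with a toy class. This is an illustration under assumptions, not the HeAT implementation: it assumes the end of the training phase is signalled by the `training` flag being set to a false value (which is what `torch.nn.Module.eval()` does), and `ToyModule`, `pending`, and `finalized` are hypothetical names.

```python
class ToyModule:
    """Toy stand-in for a module with a `training` flag (not torch.nn.Module)."""

    def __init__(self, blocking: bool = False):
        self.blocking = blocking
        self.pending = []    # names of parameters with outstanding wait handles
        self.finalized = []  # names of parameters whose handles were completed
        self.training = True

    def __setattr__(self, name, value):
        # Assumption: leaving the training phase means `training` becomes False.
        if name == "training" and not value and not self.blocking:
            self._finalize()
        super().__setattr__(name, value)

    def _finalize(self):
        # Hypothetical clean-up: mark every outstanding handle as completed.
        self.finalized.extend(self.pending)
        self.pending.clear()

    def eval(self):
        self.training = False
        return self


m = ToyModule()
m.pending.append("layer1.weight")
m.eval()  # the assignment to `training` triggers the finalization
```

Intercepting the assignment means the caller does not have to remember an explicit "finish communication" call before evaluation.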
- forward(*inputs: tuple, **kwargs: dict) -> torch.Tensor
Do the forward step for the network, receive the parameters from the last
- _iparam_update(param_slice: slice = None, layer_names: List[str] = None) -> None
Update parameters asynchronously via wait handles.
- Parameters:
param_slice (slice, optional) – Slice object for creating a view onto optimizer’s params list. By default, the whole params list is used (None)
layer_names (list(str), optional) – List of layer names whose parameters will be updated; must match param_slice. By default, all layers are updated (None)
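The pattern of completing asynchronous updates through wait handles can be sketched with the standard library. This is a rough illustration under assumptions, not the HeAT implementation: `concurrent.futures.Future` objects stand in for MPI wait handles, parameters are plain floats, and `iparam_update`, `averaged`, and the layout of `handles` (layer name mapped to `(param_index, future)` pairs) are hypothetical.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple

Handles = Dict[str, List[Tuple[int, Future]]]


def averaged(values: List[float]) -> float:
    """Work done in the background; stands in for a non-blocking reduction."""
    return sum(values) / len(values)


def iparam_update(params: List[float], handles: Handles,
                  param_slice: Optional[slice] = None,
                  layer_names: Optional[List[str]] = None) -> None:
    """Wait on the handles of the selected layers and write the results into params."""
    if param_slice is None:
        param_slice = slice(len(params))  # default: the whole params list
    if layer_names is None:
        layer_names = list(handles)       # default: all layers
    allowed = range(len(params))[param_slice]
    for name in layer_names:
        for index, future in handles.pop(name, []):
            if index not in allowed:
                raise ValueError(f"parameter {index} of {name} is outside param_slice")
            params[index] = future.result()  # blocks until this handle completes


params = [0.0, 0.0, 0.0]
with ThreadPoolExecutor(max_workers=2) as pool:
    handles: Handles = {
        "fc1": [(0, pool.submit(averaged, [1.0, 3.0]))],
        "fc2": [(1, pool.submit(averaged, [2.0, 6.0])),
                (2, pool.submit(averaged, [0.0, 1.0]))],
    }
    # Update only the first layer, whose single parameter sits at index 0.
    iparam_update(params, handles, param_slice=slice(0, 1), layer_names=["fc1"])
    after_first = list(params)
    # With the defaults, every remaining layer is updated.
    iparam_update(params, handles)
```

The first call shows why the slice and the layer names must agree: the slice limits which entries of the params list may be written, and the layer names select which handles are waited on.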
- _blocking_hook(grad_loc: torch.Tensor) -> torch.Tensor
Add a blocking hook to the PyTorch DAG for all of the backwards calls.
- Parameters:
grad_loc (torch.Tensor) – The local gradient
References
[1] (cf. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.register_hook).
- _nonblocking_hook(layer_name: str, param_name: str) -> Callable
Add a nonblocking hook to send and receive the averaged parameters after the backwards step
- Parameters:
layer_name (str) – Name of the layer
param_name (str) – Name of the parameter
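A method that takes a layer name and a parameter name and returns a Callable suggests a closure factory: each returned hook remembers which parameter it belongs to. The sketch below illustrates that pattern only and is not the HeAT implementation. A thread pool stands in for non-blocking MPI communication, gradients are plain floats, and `make_nonblocking_hook`, `simulated_mean`, `peer_grads`, and `wait_handles` are hypothetical names.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Dict, List, Tuple

WaitHandles = Dict[str, List[Tuple[str, Future]]]


def simulated_mean(grad_loc: float, peer_grads: List[float]) -> float:
    """Stand-in for a non-blocking all-reduce that averages over all ranks."""
    values = [grad_loc, *peer_grads]
    return sum(values) / len(values)


def make_nonblocking_hook(pool: ThreadPoolExecutor, wait_handles: WaitHandles,
                          layer_name: str, param_name: str,
                          peer_grads: List[float]) -> Callable[[float], float]:
    """Return a hook bound to one parameter of one layer."""

    def hook(grad_loc: float) -> float:
        # Start the averaging in the background and remember the handle by layer.
        future = pool.submit(simulated_mean, grad_loc, peer_grads)
        wait_handles.setdefault(layer_name, []).append((param_name, future))
        # The local gradient is returned unchanged; the average is collected later.
        return grad_loc

    return hook


wait_handles: WaitHandles = {}
with ThreadPoolExecutor(max_workers=1) as pool:
    hook = make_nonblocking_hook(pool, wait_handles, "fc1", "weight", peer_grads=[3.0])
    returned = hook(1.0)
    stored_name, future = wait_handles["fc1"][0]
    averaged_grad = future.result()  # a later step would wait here
```

Storing the handles per layer lets a later step wait only for the layers it is about to use, instead of stalling the backward pass on every parameter.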
- class DataParallelMultiGPU(module: torch.nn.Module, optimizer: heat.optim.DASO, comm: heat.core.communication.MPICommunication = MPI_WORLD)
Bases: torch.nn.Module
Creates data parallel networks local to each node using PyTorch’s distributed class. This does NOT do any global synchronizations. To make optimal use of this structure, use ht.optim.DASO.
Notes
The PyTorch distributed process group must already exist before this class is initialized.
- Parameters:
module (torch.nn.Module) – an implemented PyTorch model
optimizer (optim.DASO) – A DASO optimizer. Other optimizers are not yet implemented. The DASO optimizer should be defined prior to calling this class.
comm (MPICommunication, optional) – A global communicator. Default: MPI_WORLD
- module
- comm
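The difference between node-local and global synchronization can be made concrete with a small simulation. This sketch is plain Python and not HeAT or PyTorch code: it assumes ranks are assigned to nodes in contiguous blocks, and `node_local_average` and `ranks_per_node` are hypothetical names.

```python
from collections import defaultdict
from typing import Dict, List


def node_local_average(values_by_rank: List[float], ranks_per_node: int) -> List[float]:
    """Average values within each node only; nothing is exchanged between nodes."""
    groups: Dict[int, List[float]] = defaultdict(list)
    for rank, value in enumerate(values_by_rank):
        groups[rank // ranks_per_node].append(value)  # assumed contiguous placement
    node_means = {node: sum(vals) / len(vals) for node, vals in groups.items()}
    return [node_means[rank // ranks_per_node] for rank in range(len(values_by_rank))]


# Four ranks on two nodes, each holding one local gradient value.
grads = [1.0, 3.0, 10.0, 20.0]
synced = node_local_average(grads, ranks_per_node=2)
```

Ranks on the same node end up with the same value while the two nodes still disagree, which is why a separate global step, such as the one the text attributes to ht.optim.DASO, is needed on top of this class.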