heat.nn.data_parallel
This module contains the general data parallel neural network classes.
Module Contents
- class DataParallel(module: torch.nn.Module, comm: heat.core.communication.MPICommunication, optimizer: heat.optim.DataParallelOptimizer | List | Tuple, blocking_parameter_updates: bool = False)
Bases:
torch.nn.Module
Implements data parallelism across multiple processes. This means that the same model is run locally on each process. Creation of the model is similar to PyTorch; the only change is the use of HeAT layers (ht.nn.layer) in the initialization of the network/optimizer. If there is no HeAT layer, it will fall back to the PyTorch layer of the same name. The same is true for the optimizer. It is possible to use more than one optimizer, but then communication during parameter updates is limited to blocking. The same limitation applies when passing an optimizer that does not operate on exactly the model's set of parameters. For the given model, both the __init__() and forward() functions must be defined in the class defining the network. An example of this is shown in examples/mnist.py and sketched after the variable list below.
It is highly recommended that a HeAT DataLoader is used, see ht.utils.data.DataLoader. The default communication scheme is blocking: the model parameters are averaged during the backward step and synchronized before the next model iteration. Note that using more than one optimizer forces parameter updates to use blocking MPI communication.
- Variables:
module (torch.nn.Module) – The local module
comm (MPICommunication) – Communicator to use
optimizer (heat.DataParallelOptimizer, List, Tuple) – Individual or sequence of DataParallelOptimizers to be used
blocking_parameter_updates (bool, optional) – Flag indicating the use of blocking communications for parameter updates. Default: non-blocking updates (False)
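For illustration, a minimal construction sketch. The two-layer network and its hyperparameters are assumptions made for this example; see examples/mnist.py in the HeAT repository for a complete, tested version.

```python
import heat as ht
import torch

class Net(torch.nn.Module):
    # Hypothetical two-layer network; any torch.nn.Module defining both
    # __init__() and forward() works the same way.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = Net()
# Wrap a standard torch optimizer so DataParallel can coordinate updates.
dp_optimizer = ht.optim.DataParallelOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01)
)
dp_model = ht.nn.DataParallel(
    model,
    comm=ht.MPI_WORLD,
    optimizer=dp_optimizer,
    blocking_parameter_updates=False,  # default: non-blocking updates
)
```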
- __setattr__(name: str, value: torch.nn.Module | torch.Tensor | Any) None
Overwrite the current torch.nn.Module.__setattr__ so that it auto-detects the end of an epoch's training phase and finalizes the wait handles (only relevant for non-blocking updates)
- forward(*inputs: tuple, **kwargs: dict) torch.Tensor
Do the forward step for the network and receive the parameters from the last backward step if required (sketched below).
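As a sketch of how forward() fits into a training step, assuming the dp_model and dp_optimizer from the construction sketch above and a hypothetical train_loader (ideally a HeAT DataLoader):

```python
loss_fn = torch.nn.CrossEntropyLoss()
for data, target in train_loader:  # train_loader: hypothetical HeAT DataLoader
    dp_optimizer.zero_grad()
    output = dp_model(data)   # forward step; pending non-blocking parameter
                              # updates are finalized here
    loss = loss_fn(output, target)
    loss.backward()           # backward hooks average across processes
    dp_optimizer.step()
```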
- _iparam_update(param_slice: slice = None, layer_names: List[str] = None) None
Update parameters asynchronously via wait handles.
- Parameters:
param_slice (slice, optional) – Slice object for creating a view onto the optimizer's params list. By default, the whole params list is used (None)
layer_names (list(str), optional) – List of layer names whose parameters will be updated; must match param_slice. By default, all layers are updated (None)
- _blocking_hook(grad_loc: torch.Tensor) torch.Tensor
Add a blocking hook to the PyTorch DAG for all of the backward calls (see the sketch below).
- Parameters:
grad_loc (torch.Tensor) – The local gradient
References
[1] (cf. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.register_hook).
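As an illustration of the blocking scheme, a simplified sketch using mpi4py directly; the actual implementation communicates through HeAT's MPICommunication and registers the hook on each parameter via torch.Tensor.register_hook.

```python
import torch
from mpi4py import MPI

def blocking_hook(grad_loc: torch.Tensor) -> torch.Tensor:
    # Average the local gradient over all processes; Allreduce blocks until
    # every process has contributed, synchronizing the update.
    grad = grad_loc.detach().clone().cpu()
    MPI.COMM_WORLD.Allreduce(MPI.IN_PLACE, grad.numpy(), op=MPI.SUM)
    return (grad / MPI.COMM_WORLD.size).to(grad_loc.device)

# param.register_hook(blocking_hook)  # attached to each model parameter
```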
- _nonblocking_hook(layer_name: str, param_name: str) Callable
Add a non-blocking hook to send and receive the averaged parameters after the backward step (see the sketch below).
- Parameters:
layer_name (str) – Name of the layer
param_name (str) – Name of the parameter
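A simplified sketch of the non-blocking scheme; the wait_handles dictionary is a hypothetical stand-in for the per-layer handle bookkeeping the class performs internally.

```python
from typing import Callable, Dict, Tuple

import torch
from mpi4py import MPI

# Pending (request, buffer) pairs, keyed by (layer_name, param_name).
wait_handles: Dict[Tuple[str, str], tuple] = {}

def nonblocking_hook(layer_name: str, param_name: str) -> Callable:
    def hook(grad_loc: torch.Tensor) -> torch.Tensor:
        buf = grad_loc.detach().clone().cpu()
        # Start the averaging but do not wait for it; training continues
        # while the reduction completes in the background.
        handle = MPI.COMM_WORLD.Iallreduce(MPI.IN_PLACE, buf.numpy(), op=MPI.SUM)
        wait_handles[(layer_name, param_name)] = (handle, buf)
        return grad_loc
    return hook
```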
- _forward_hook(layer_name: str) Callable
Add a forward hook to update parameters during the forward step. This will return a hook which can be added using the submodule.register_forward_pre_hook command (see the sketch below).
- Parameters:
layer_name (str) – Name of the layer
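Continuing the simplified non-blocking sketch above (same hypothetical wait_handles state), a pre-hook factory whose returned hook matches the (module, inputs) signature expected by register_forward_pre_hook:

```python
def forward_hook(layer_name: str) -> Callable:
    def hook(module: torch.nn.Module, inputs: tuple) -> None:
        # Wait on and apply every pending update belonging to this layer
        # before its forward pass runs.
        for (lname, pname), (handle, buf) in list(wait_handles.items()):
            if lname == layer_name:
                handle.Wait()  # finalize the Iallreduce from the backward step
                getattr(module, pname).data.copy_(buf / MPI.COMM_WORLD.size)
                del wait_handles[(lname, pname)]
    return hook

# submodule.register_forward_pre_hook(forward_hook("fc1"))
```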
- _reset_parameters(module: torch.nn.Module) None
Reset parameters of the given torch submodule. Only works for basic module types that provide a reset_parameters function (see the sketch below).
- Parameters:
module (torch.nn.Module) – Submodule whose parameters are to be reset
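The equivalent idiom in plain PyTorch, for illustration only (not the internal implementation):

```python
import torch

def reset_parameters(module: torch.nn.Module) -> None:
    # Only basic module types (e.g. torch.nn.Linear, torch.nn.Conv2d)
    # provide reset_parameters; anything else is skipped.
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

# model.apply(reset_parameters)  # visits every submodule recursively
```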
- class DataParallelMultiGPU(module: torch.nn.Module, optimizer: heat.optim.DASO, comm: heat.core.communication.MPICommunication = MPI_WORLD)
Bases:
torch.nn.Module
This creates data parallel networks local to each node using PyTorch's distributed class. This does NOT do any global synchronization. To make optimal use of this structure, use ht.optim.DASO (a wiring sketch follows the parameter list below).
Notes
The PyTorch distributed process group must already exist before this class is initialized.
- Parameters:
module (torch.nn.Module) – an implemented PyTorch model
optimizer (optim.DASO) – A DASO optimizer. Other optimizers are not yet implemented. The DASO optimizer should be defined prior to calling this class.
comm (MPICommunication, optional) – A global communicator. Default: MPI_WORLD
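A minimal wiring sketch, reusing the hypothetical Net model from the first example; DASO's keyword arguments beyond local_optimizer and total_epochs are omitted, and total_epochs=50 is an assumed value.

```python
import torch
import heat as ht

# The node-local PyTorch distributed process group must exist first.
torch.distributed.init_process_group(backend="nccl")

model = Net()  # hypothetical model from the first sketch
daso = ht.optim.DASO(
    local_optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    total_epochs=50,
)
dp_model = ht.nn.DataParallelMultiGPU(model, daso, comm=ht.MPI_WORLD)
```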
- forward(*inputs: Tuple, **kwargs: Dict) torch.Tensor
Calls the forward method of the torch model.
- _reset_parameters(module: torch.nn.Module) None
Reset parameters of the given torch submodule. Only works for basic module types that provide a reset_parameters function.
- Parameters:
module (torch.nn.Module) – Submodule whose parameters are to be reset