Source code for heat.nn.data_parallel

"""
General data parallel neural network classes.
"""

import warnings
import torch
import torch.distributed
import torch.nn as tnn

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Union, Tuple

from .. import optim
from ..core.communication import MPI
from ..core.communication import MPI_WORLD
from ..core.communication import MPICommunication


__all__ = ["DataParallel", "DataParallelMultiGPU"]


class DataParallel(tnn.Module):
    """
    Implements data parallelism across multiple processes. This means that the same model will be
    run locally on each process. Creation of the model is similar to PyTorch, the only changes are
    using HeAT layers (ht.nn.layer) in the initialization of the network/optimizer. If there is not
    a HeAT layer, it will fall back to the PyTorch layer of the same name. The same is true for the
    optimizer. It's possible to use more than one optimizer, but communication during parameter
    updates is limited to blocking. The same limitation takes effect when passing an optimizer that
    does not deal exactly with the set of model's parameters. For the given model both the
    ``__init__()`` and ``forward()`` functions must be defined in the class defining the network.

    An example of this is shown in `examples/mnist.py
    <https://github.com/helmholtz-analytics/heat/blob/504-docstring-formatting/examples/nn/mnist.py>`_.

    It is highly recommended that a HeAT DataLoader is used, see
    :func:`ht.utils.data.DataLoader <heat.utils.data.datatools.DataLoader>`.

    With blocking communication the gradients are averaged across all processes during the
    backward step, i.e. they are synchronized before the next model iteration. With non-blocking
    communication the reduction is only started during the backward step and finalized right
    before the corresponding layer is used in the next forward step. Usage of more than one
    optimizer forces the parameter updates to use blocking MPI communication.

    Attributes
    ----------
    module : torch.nn.Module
        The local module
    comm : MPICommunication
        Communicator to use
    optimizer : heat.DataParallelOptimizer, List, Tuple
        Individual or sequence of DataParallelOptimizers to be used
    blocking_parameter_updates : bool, optional
        Flag indicating the usage of blocking communications for parameter updates
        Default: non-blocking updates (``False``)

    Raises
    ------
    TypeError
        If ``optimizer`` is a DASO optimizer, or if any given optimizer is not a
        ``DataParallelOptimizer``.

    Warnings
    --------
    Constructing this class seeds torch's global random number generator with a fixed value and
    resets the parameters of every submodule that offers ``reset_parameters()``.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        comm: MPICommunication,
        optimizer: Union[optim.DataParallelOptimizer, List, Tuple],
        blocking_parameter_updates: bool = False,
    ):  # noqa: D107
        if isinstance(optimizer, optim.DASO):
            raise TypeError(
                "For use with DASO please use DataParallelMultiGPU instead of DataParallel"
            )
        super(DataParallel, self).__init__()
        self.module = module
        self.comm = comm
        self.blocking_parameter_updates = blocking_parameter_updates

        self._dp_optimizers = []
        # layer name -> list of (param name, wait handle, original dtype, reduction buffer)
        self._layer_wait_handles = OrderedDict()
        self._fwd_hook_handles = []
        # set of layers' names with active wait handles (only relevant for non-blocking)
        self._active_layers = set()
        # slices of parameters belonging to one and the same layer
        self._param_slices = {}
        # pytorch internal parameter indexing
        self._param_indices = {}

        # raise error if no DP optimizer is given
        if not isinstance(optimizer, (list, tuple)):
            optimizer = [optimizer]
        for candidate in optimizer:
            if not isinstance(candidate, optim.DataParallelOptimizer):
                raise TypeError("optimizers must be optim.DataParallelOptimizer")

        # current implementation of non-blocking communication during parameter updates has some
        # limitations that cause fallback onto blocking in case of overstepping them
        if not self.blocking_parameter_updates:
            fallback_reason = None
            if len(optimizer) > 1:
                fallback_reason = (
                    "Usage of more than one DataParallelOptimizer causes fallback on blocking MPI "
                    "communication during parameter updates."
                )
            elif not self._holds_same_parameters(
                list(module.parameters()),
                optimizer[0].torch_optimizer.param_groups[0]["params"],
            ):
                fallback_reason = (
                    "Usage of a DataParallelOptimizer that does not deal exactly with the set of "
                    "the model's parameters causes fallback on blocking MPI communication during "
                    "parameter updates."
                )
            if fallback_reason is not None:
                self.blocking_parameter_updates = True
                warnings.warn(fallback_reason, stacklevel=2)

        # assign given optimizers to this model
        for dp_optimizer in optimizer:
            self._dp_optimizers.append(dp_optimizer)
            dp_optimizer.blocking_parameter_updates = self.blocking_parameter_updates

        # unify parameters across nodes by unifying the random seed and resetting parameters
        torch.random.manual_seed(2147483646)  # max int32 value - 1
        self.module.apply(self._reset_parameters)

        # get parameter indexing and slices
        start_idx = 0
        layer_name_prev = None
        for idx, (name, param) in enumerate(module.named_parameters()):
            self._param_indices[name] = idx
            # everything in front of the last dot is the name of the owning layer
            layer_name = name.rsplit(sep=".", maxsplit=1)[0]
            if layer_name_prev is None:
                layer_name_prev = layer_name
            if layer_name_prev != layer_name:
                # a new layer starts here, close the slice of the previous one
                self._param_slices[layer_name_prev] = slice(start_idx, idx)
                layer_name_prev = layer_name
                start_idx = idx

            # frozen parameters never receive a gradient and torch refuses to register hooks on
            # tensors that do not require one
            if not param.requires_grad:
                continue

            # register backward hooks for all model parameter tensors
            if self.blocking_parameter_updates:
                param.register_hook(self._blocking_hook)
            else:
                param.register_hook(self._nonblocking_hook(layer_name, name))
        self._param_slices[layer_name_prev] = slice(start_idx, len(self._param_indices))

    @staticmethod
    def _holds_same_parameters(model_params: List, optimizer_params: List) -> bool:
        """
        Check whether two parameter sequences reference the very same tensors in the same order.

        The comparison is done by identity. Comparing the lists with ``==``/``!=`` would evaluate
        ``Tensor.__eq__`` element-wise for differing tensors, whose truth value is ambiguous and
        raises for tensors with more than one element.

        Parameters
        ----------
        model_params : list
            Parameters of the model
        optimizer_params : list
            Parameters the optimizer is working on
        """
        return len(model_params) == len(optimizer_params) and all(
            model_param is optimizer_param
            for model_param, optimizer_param in zip(model_params, optimizer_params)
        )

    def __setattr__(self, name: str, value: Union[torch.nn.Module, torch.Tensor, Any]) -> None:
        """
        Overwrite the current torch.nn.Module.__setattr__ so that it auto-detects the end of epoch's
        training phase and finalize wait handles (only relevant for non-blocking)

        Parameters
        ----------
        name : str
            Name of the attribute to set
        value : torch.nn.Module or torch.Tensor or Any
            Value to assign to the attribute
        """
        # setting ``training`` to a falsy value marks the end of the training phase
        if name == "training" and not value and not self.blocking_parameter_updates:
            self._iparam_update()
        super(DataParallel, self).__setattr__(name, value)

    def forward(self, *inputs: tuple, **kwargs: dict) -> torch.Tensor:
        """
        Do the forward step for the network. In non-blocking training mode, the gradient
        reductions started during the last backward step are finalized and the corresponding
        parameters are updated right before their layer is used.

        Parameters
        ----------
        *inputs : tuple
            Positional arguments passed on to the wrapped module
        **kwargs : dict
            Keyword arguments passed on to the wrapped module
        """
        # check if non-blocking and training
        if not self.blocking_parameter_updates and self.module.training:
            # register forward hooks for all layers with wait handles
            for name, submodule in self.module.named_modules():
                if name == "":
                    continue
                if name in self._layer_wait_handles:
                    hook_handle = submodule.register_forward_pre_hook(self._forward_hook(name))
                    self._fwd_hook_handles.append(hook_handle)

        # perform forward pass
        ret = self.module(*inputs, **kwargs)

        # finalize potentially remaining wait handles and update corresponding params (if
        # computation graph has changed between previous backward and this forward)
        if not self.blocking_parameter_updates and self.module.training:
            # set has to be copied in order to be manipulated during iteration
            active_layers_cpy = self._active_layers.copy()
            for layer_name in active_layers_cpy:
                self._forward_hook(layer_name)(None, None)
            # reset optimizer flag
            for ldp_optimizer in self._dp_optimizers:
                ldp_optimizer.update_next = False
            # clear dictionary after all wait handles are used up (dynamic computation graph)
            self._layer_wait_handles.clear()
            # remove forward hooks (dynamic computation graph)
            for hook_handle in self._fwd_hook_handles:
                hook_handle.remove()
            self._fwd_hook_handles.clear()

        return ret

    def _iparam_update(
        self, param_slice: Union[slice, None] = None, layer_names: Union[List[str], None] = None
    ) -> None:
        r"""
        Update parameters asynchronously via wait handles.

        Parameters
        ----------
        param_slice : slice, optional
            Slice object for creating a view onto optimizer's params list.\n
            By default, the whole params list is used, (``None``)
        layer_names : list(str), optional
            List of layer names which parameters will be updated, must match param_slice.\n
            By default, all layers are updated (``None``)

        Raises
        ------
        ValueError
            If a reduced gradient does not have the shape of the parameter's local gradient.
        """
        # for non-blocking, only one dp optimizer is allowed
        dp_optimizer = self._dp_optimizers[0]

        # perform update on the whole model
        if param_slice is None:
            param_slice = slice(len(dp_optimizer.params_ref))
        if layer_names is None:
            layer_names = list(self._layer_wait_handles.keys())

        # update params that are visible for the optimizer
        dp_optimizer.torch_optimizer.param_groups[0]["params"] = dp_optimizer.params_ref[
            param_slice
        ]

        # iterate over layers
        for layer_name in reversed(layer_names):
            # only perform update, if all given layers hold unfinalized wait handles
            # (important for layer reuse)
            if layer_name not in self._active_layers:
                return
            # iterate over layer's parameters/associated wait handles
            for param_name, wait_handle, dtp, tens in self._layer_wait_handles[layer_name]:
                # get internal index of selected parameter
                param_idx = self._param_indices[param_name]
                # synchronize, get parameter's global gradient
                wait_handle.wait()
                # check if shapes are matching
                if dp_optimizer.params_ref[param_idx].grad.data.shape != tens.shape:
                    raise ValueError("Shapes must be equal.")
                # accumulate parameter's global gradient, cast back to the gradient's own dtype
                dp_optimizer.params_ref[param_idx].grad.data += tens.to(dtp)
            # remove layer from set of active layers, if present
            self._active_layers.discard(layer_name)

        # if desired, perform actual parameter update
        if dp_optimizer.update_next:
            dp_optimizer.torch_optimizer.step()

    def _blocking_hook(self, grad_loc: torch.Tensor) -> torch.Tensor:
        """
        Add a blocking hook to the PyTorch DAG for all of the backwards calls.

        Parameters
        ----------
        grad_loc : torch.Tensor
            The local gradient

        Returns
        -------
        torch.Tensor
            The gradient averaged over all processes of the communicator, in the dtype of
            ``grad_loc``

        References
        ----------
        [1] (cf. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.register_hook).
        """
        # hooks must not modify their argument: ``to`` without ``copy=True`` returns the very same
        # tensor for float32 gradients, and the in-place operations below would then alter it
        grad_glob = grad_loc.to(torch.float, copy=True)
        # average local gradients
        grad_glob *= 1 / float(self.comm.size)
        # perform MPI Allreduce to compute global gradient
        self.comm.Allreduce(MPI.IN_PLACE, grad_glob, MPI.SUM)
        return grad_glob.to(grad_loc.dtype)

    def _nonblocking_hook(self, layer_name: str, param_name: str) -> Callable:
        """
        Add a nonblocking hook to send and receive the averaged parameters after the backwards step

        Parameters
        ----------
        layer_name : str
            Name of the layer
        param_name : str
            Name of the parameter

        Returns
        -------
        Callable
            Hook to be registered on the parameter tensor via ``register_hook``
        """

        # hook function for non-blocking gradient data exchange
        def _hook(grad_loc: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                # work on a private copy, hooks must not modify their argument
                wrk = grad_loc.to(torch.float, copy=True)
                # counterbalance local gradient averaging
                wrk *= 1 / float(self.comm.size)
                # perform MPI IAllreduce to compute global gradient, returns wait handle
                wait_handle = self.comm.Iallreduce(MPI.IN_PLACE, wrk, MPI.SUM)
                # if layer wait handle dict does not contain the layer, add it
                # -> automatically tracks reversed layer order
                if layer_name not in self._layer_wait_handles:
                    self._layer_wait_handles[layer_name] = []
                # add layer to set of active layers
                self._active_layers.add(layer_name)
                # assign wait handle to its layer, handles are kept in order of arrival
                # TODO: is sorting the handles by size faster? or is there any difference?
                self._layer_wait_handles[layer_name].append(
                    (param_name, wait_handle, grad_loc.dtype, wrk)
                )
                # don't return grad_loc, otherwise gradient is doubled; the placeholder has to
                # keep shape, dtype and device of the local gradient
                return torch.zeros_like(grad_loc)

        return _hook

    def _forward_hook(self, layer_name: str) -> Callable:
        """
        Add a forward hook to update parameters during the forward step. This will return a hook
        which can be added using the ``submodule.register_forward_pre_hook`` command.

        Parameters
        ----------
        layer_name : str
            Name of the layer

        Returns
        -------
        Callable
            Hook that updates the parameters of ``layer_name`` and returns its input unchanged
        """

        # hook function for non-blocking parameter update
        def _hook(_, input_):
            # update parameters of given layer
            param_slice = self._param_slices[layer_name]
            self._iparam_update(param_slice, [layer_name])
            return input_

        return _hook

    @staticmethod
    def _reset_parameters(module: tnn.Module) -> None:
        """
        Reset parameters of given torch submodule. Only works for basic module types containing
        ``reset_parameters`` function.

        Parameters
        ----------
        module: torch.nn.Module
            Submodule whose parameters are to be reset
        """
        if callable(getattr(module, "reset_parameters", None)):
            module.reset_parameters()
class DataParallelMultiGPU(tnn.Module):
    """
    Creates data parallel networks local to each node using PyTorch's distributed class. This does
    NOT do any global synchronizations. To make optimal use of this structure, use
    :func:`ht.optim.DASO <heat.optim.dp_optimizer.DASO>`.

    Notes
    -----
    The PyTorch distributed process group must already exist before this class is initialized.

    Parameters
    ----------
    module: torch.nn.Module
        an implemented PyTorch model
    optimizer: optim.DASO
        A DASO optimizer. Other optimizers are not yet implemented. The DASO optimizer should be
        defined prior to calling this class.
    comm: MPICommunication, optional
        A global communicator.
        Default: :func:`MPICommunication <heat.core.comm.MPICommunication>`

    Warnings
    --------
    Constructing this class seeds torch's global random number generator with a fixed value and
    resets the parameters of every submodule that offers ``reset_parameters()``.
    """

    def __init__(
        self, module: torch.nn.Module, optimizer: optim.DASO, comm: MPICommunication = MPI_WORLD
    ):  # noqa: D107
        super(DataParallelMultiGPU, self).__init__()
        rank = comm.rank
        if torch.cuda.device_count() > 1:
            self.loc_gpus = torch.cuda.device_count()
            # map the global rank onto one of the GPUs visible to this process
            local_rank = rank % self.loc_gpus
            device = f"cuda:{local_rank}"
            torch.cuda.set_device(device=device)
            module = tnn.parallel.DistributedDataParallel(module, device_ids=[local_rank])
        else:
            warnings.warn(
                "DataParallelMultiGPU should be used with multiple GPUs per node", UserWarning
            )
        self.module = module
        self.comm = comm

        # unify parameters across nodes by unifying the random seed and resetting parameters;
        # without the common seed every process would draw its own initial parameters
        torch.random.manual_seed(2147483646)  # max int32 value - 1
        self.module.apply(self._reset_parameters)

        optimizer.set_model(self.module)

    def forward(self, *inputs: Tuple, **kwargs: Dict) -> torch.Tensor:
        """
        Calls the forward method for the torch model

        Parameters
        ----------
        *inputs : tuple
            Positional arguments passed on to the wrapped module
        **kwargs : dict
            Keyword arguments passed on to the wrapped module
        """
        return self.module(*inputs, **kwargs)

    @staticmethod
    def _reset_parameters(module: tnn.Module) -> None:
        """
        Reset parameters of given torch submodule. Only works for basic module types containing
        ``reset_parameters`` function.

        Parameters
        ----------
        module: torch.nn.Module
            Submodule whose parameters are to be reset
        """
        if callable(getattr(module, "reset_parameters", None)):
            module.reset_parameters()