:mod:`heat.optim.dp_optimizer`
==============================

.. py:module:: heat.optim.dp_optimizer

.. autoapi-nested-parse::

   MPI-enabled data-parallel optimizers



Module Contents
---------------

.. py:class:: DASO(local_optimizer: torch.optim.Optimizer, total_epochs: int, comm: heat.core.communication.MPICommunication = MPI_WORLD, warmup_epochs: int = 4, cooldown_epochs: int = 4, scheduler: torch.optim.lr_scheduler = None, stability_level: float = 0.05, max_global_skips: int = 8, sending_chunk_size: int = 10000000, downcast_type: torch.dtype = torch.bfloat16, use_mpi_groups: bool = True, skip_reduction_factor: int = 2, local_skip_factor: int = 4, verbose: bool = False)

   Optimizer wrapper to use the Distributed Asynchronous and Selective Optimization (DASO) method.

   This optimizer uses a local torch optimizer combined with :func:`nn.DataParallelMultiGPU` to create local DPNNs on each node, each consisting of the GPUs on that node. These networks then communicate globally through MPI groups, each of which contains a single GPU from each node.
   DASO uses both local and global synchronization operations. Local synchronization operations are intended to be performed very frequently, while global synchronizations are conducted asynchronously as the next batches are computed.

   This implementation requires that all nodes have the same number of GPUs.

   There are four phases to training:

   1. Initialization: steps 1 to 8 below.
   2. Warmup phase: a blocking averaging update occurs for the global synchronization step.
   3. Cycling phase: for the global synchronization, the data is sent after a number of batches. The number of batches between synchronizations is referred to as `global_skips`. After the data is sent, a number of batches pass before it is received (`batches_to_wait`). Both of these values cycle downward, starting from `max_global_skips` for the global skips and from 1/4th of this value for `batches_to_wait`.
      When both values are equal to 1 and the loss is stable, they are reset to their initial values and then decay again.
   4. Cooldown phase: a blocking averaging update occurs for the global synchronization step.

   An example usage can be found in ``heat/examples/nn/imagenet-DASO.py``.

   The recommended checklist for using this class is as follows:

   1. Initialize the local PyTorch process group and set the default device of the local GPUs.
   2. Define the torch network.
   3. Define the `local_optimizer`: a torch optimizer of your choice (tested with SGD).
   4. Optionally, choose a learning rate scheduler. This is only needed for learning rate schedulers whose `step` function also steps the optimizer.
   5. Initialize DASO with the local optimizer and parameters.
   6. Initialize :func:`nn.DataParallelMultiGPU` with the torch network and DASO.
   7. If using automatic mixed precision (:class:`torch.cuda.amp`), initialize the gradient scaler and add it to DASO (:func:`add_scaler`).
   8. Ensure that the DataLoaders evenly distribute the data among all the processes. This can be done by using ``torch.utils.data.distributed.DistributedSampler`` with the `num_replicas` and `rank` parameters.
   9. Call `daso_optimizer.epoch_loss_logic(training_loss)` at the end of each epoch.
   10. Set the number of batches per epoch (`daso_optimizer.last_batch = number_of_batches`).
   11. Ensure that the step function used in training is that of the DASO optimizer.

   :param local_optimizer: This optimizer handles the optimization of the local NN, for example `torch.optim.SGD`. This can be any optimizer, although tests were only completed with SGD; other optimizers may show unexpected behavior.
   :type local_optimizer: torch.optim.Optimizer
   :param total_epochs: The total number of epochs for training. Needed to determine when to enter the cooldown phase.
   :type total_epochs: int
   :param comm: The MPI communicator to use for training.
      Default: :func:`MPI_WORLD`
   :type comm: MPICommunication, optional
   :param warmup_epochs: The number of epochs to complete with a blocking averaging operation after each batch before entering the cycling phase. Default: 4
   :type warmup_epochs: int, optional
   :param cooldown_epochs: The number of epochs at the end of training with blocking averaging operations after each batch. Default: 4
   :type cooldown_epochs: int, optional
   :param scheduler: Local PyTorch learning rate scheduler. This must be used if the scheduler's `step` function is to be called instead of the optimizer's `step` function. Default: None
   :type scheduler: torch.optim.lr_scheduler, optional
   :param stability_level: This can be viewed as the percent change threshold that the loss must exceed to be judged as improving. When the loss stays within this percent change for 2 epochs, it is judged to be stable. Default: 0.05
   :type stability_level: float, optional
   :param max_global_skips: The maximum number of batches between the starts of consecutive global synchronization processes. Default: 8
   :type max_global_skips: int, optional
   :param sending_chunk_size: During the global synchronization step, the network parameters are split into chunks of data to overlap communication and computation. This value is the maximum chunk size. Default: 10,000,000
   :type sending_chunk_size: int, optional
   :param downcast_type: Options: [torch.bfloat16, torch.half, torch.float]. When the network parameters are sent during the global synchronization step, they are cast down to a smaller dtype, which is `torch.bfloat16` by default. Smaller torch dtypes are not implemented. Default: torch.bfloat16
   :type downcast_type: torch.dtype, optional
   :param use_mpi_groups: Use MPI groups to divide the global communicator.
      If True, use MPI groups; otherwise, use MPI split. Default: True
   :type use_mpi_groups: bool, optional
   :param skip_reduction_factor: The factor by which the global/local skips are reduced when the loss has stabilized. Default: 2
   :type skip_reduction_factor: int, optional
   :param local_skip_factor: How many local skips occur per global skip, i.e. the number of local skips is ``global_skips // local_skip_factor``. Default: 4
   :type local_skip_factor: int, optional
   :param verbose: If True, print out a collection of debug messages. Default: False
   :type verbose: bool, optional

   .. attribute:: cast_dtype
      :annotation: = Ellipsis

   .. attribute:: comm

   .. attribute:: verbose
      :annotation: = False

   .. attribute:: local_optimizer

   .. attribute:: params_ref

   .. attribute:: scheduler
      :annotation: = None

   .. attribute:: loc_gpus
      :annotation: = 0

   .. attribute:: local_skip
      :annotation: = 1

   .. attribute:: _prev_params
      :annotation: = []

   .. attribute:: epoch
      :annotation: = 0

   .. attribute:: global_skip
      :annotation: = 0

   .. attribute:: batches_to_wait
      :annotation: = 0

   .. attribute:: max_gs
      :annotation: = 8

   .. attribute:: warmup_epochs
      :annotation: = 4

   .. attribute:: cooldown_epochs
      :annotation: = 4

   .. attribute:: total_epochs

   .. attribute:: _param_send_buffer_size
      :annotation: = None

   .. attribute:: _param_send_shp
      :annotation: = None

   .. attribute:: split
      :annotation: = None

   .. attribute:: skip_reduction_factor
      :annotation: = 2

   .. attribute:: local_skip_factor
      :annotation: = 4

   .. attribute:: stability

   .. attribute:: _gs8_waits
      :annotation: = 3

   .. attribute:: _gs8_waited
      :annotation: = 0

   .. attribute:: split_val
      :annotation: = 10000000

   .. attribute:: split_inds
      :annotation: = None

   .. attribute:: amp
      :annotation: = False

   .. role:: raw-html(raw)
      :format: html

   .. method:: add_scaler(scaler: torch.cuda.amp.GradScaler) -> None

      Create a reference to torch's ``torch.cuda.amp.GradScaler`` used in torch's automatic mixed precision.

      :param scaler: the gradient scaler to be used
      :type scaler: torch.cuda.amp.GradScaler

   .. method:: __init_checktypes(args: Dict) -> None

   .. method:: epoch_loss_logic(loss: Union[torch.Tensor, int, float], loss_globally_averaged: bool = False) -> None

      Function controlling the number of batches between global synchronizations and the number of batches to wait before receiving the sent parameters. The warm-up and cool-down phases are also controlled here.

      This function should be called at the end of each epoch with the training loss value at the end of the epoch.

      The number of batches between local synchronizations can also be modified here with minor code adjustments.

      :param loss: loss value of the current epoch
      :type loss: torch.Tensor or float
      :param loss_globally_averaged: whether the loss is already globally averaged
      :type loss_globally_averaged: bool, optional

   .. method:: _global_sync(batches_to_wait: int) -> None

      Performs a global synchronization. If `batches_to_wait > 0`, this will wait for that many batches before receiving the parameters.

      Full syncs are only performed on a single MPI group.

   .. method:: _gs_create_param_dict() -> Tuple[Dict, Dict]

      Create the shape and parameter dictionaries used for sending parameters around the MPI world. This will also define the buffer size if it was not previously defined.

   .. method:: _gs_rcv_update_params() -> None

      Receive the previously sent parameters for the last sending MPI group. This is also where the sent and local parameters are merged together.

   .. method:: _gs_rcv_update_params_last_batch(current_ranks: Tuple) -> None

      Abstracted receive for the last batch (and if `global_skips` == 0).

   .. method:: _gs_send_params(current_comm: heat.core.communication.MPICommunication, batches_to_wait: int) -> None

      Pack and send the data required for a global synchronization on the `current_comm` group.
      `batches_to_wait` is sent with the parameters to keep track of this value between sending and receiving.

   .. method:: _local_update(sending_process: Tuple) -> None

   .. method:: __pack_data(jtparams: torch.Tensor, iter_dict: Dict[str, torch.Tensor], cast: int) -> torch.Tensor

      Jitted loop to pack the data into a flattened buffer to be sent.

   .. method:: print0(*args, **kwargs) -> None

      Print a message on rank 0 if the class parameter `verbose` is set.

   .. method:: reset() -> None

      Reset the optimizer to its base state.

   .. method:: set_model(model: torch.nn.Module) -> None

      Set the local model for the optimizer. This should be called during the init of :func:`nn.DataParallelMultiGPU`. However, this can also be called manually.

      :param model: the local torch model.
      :type model: torch.nn.Module

   .. method:: _start_local_sync() -> None

      *Start* local synchronizations for the next batches.

   .. method:: step() -> None

      Perform a single optimization step. This will perform the `step` operations of the local optimizer, the local learning rate scheduler (if defined), and the gradient scaler used in automatic mixed precision (if defined).

      The step function also contains the logic that decides when to send and receive the global/local synchronizations. Global syncs occur on batches for which the batch number modulo `global_skip` is 0. If `batches_to_wait` > 0, the next batches have only local syncs. After that number of batches, the data sent during the global sync phase is received.

      Local synchronizations can also be skipped, if desired, by increasing `local_skips` above 1.

      .. rubric:: Notes

      ``self.last_batch`` must be set!

   .. method:: _stop_local_sync() -> None

      Stop local synchronizations for the next batches.

   .. method:: zero_grad() -> None

      Reset the gradients of the local optimizer's parameters.


.. py:class:: DataParallelOptimizer(torch_optimizer: torch.optim.Optimizer, blocking: bool = False)

   Uses a torch.optim.Optimizer for data parallelism. It should be used in combination with the DataParallel (DP) class. To optimize a DP module, the DP optimizer has to be passed to the DP module during its initialization. See :func:`nn.DataParallel` for a basic example of usage.

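
   As a rough conceptual sketch (plain Python, no MPI, and not Heat's implementation), synchronous data parallelism can be pictured as follows: each process computes a gradient on its own shard of the data, the gradients are averaged (the role an MPI all-reduce plays in a real setup), and every replica applies the same update, so the replicas stay identical. The functions `local_gradient` and `data_parallel_sgd_step` and the one-parameter model ``y ~ w * x`` are hypothetical and chosen only for illustration.

   .. code-block:: python

      from typing import List, Tuple

      Shard = Tuple[List[float], List[float]]  # (inputs, targets) held by one process


      def local_gradient(w: float, shard: Shard) -> float:
          """Gradient of the mean squared error of y ~ w * x on one process's shard."""
          xs, ys = shard
          return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


      def data_parallel_sgd_step(weights: List[float], shards: List[Shard], lr: float) -> List[float]:
          """One synchronous data-parallel SGD step over simulated processes.

          weights[i] is the model replica held by process i, shards[i] its local data.
          """
          grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
          # Stand-in for an MPI all-reduce (sum) followed by division by the process count.
          avg_grad = sum(grads) / len(grads)
          # Every replica applies the same averaged gradient, so all replicas stay identical.
          return [w - lr * avg_grad for w in weights]


      # Two simulated processes, both holding data that follows y = 2 * x.
      shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
      weights = [0.0, 0.0]
      for _ in range(200):
          weights = data_parallel_sgd_step(weights, shards, lr=0.01)
      print(weights)  # both replicas end up close to 2.0

   Because each process's gradient enters the average with equal weight, the averaged gradient matches the full-batch gradient only when every process holds the same amount of data, which is one reason to distribute the data evenly between the processes.
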

   :ivar torch_optimizer: the wrapped Torch optimizer
   :vartype torch_optimizer: torch.optim.Optimizer
   :ivar blocking: whether to use blocking communications. This will typically be overwritten by :func:`nn.DataParallel`.
   :vartype blocking: bool

   .. attribute:: torch_optimizer

   .. attribute:: blocking_parameter_updates
      :annotation: = False

   .. attribute:: update_next
      :annotation: = False

   .. attribute:: params_ref

   .. role:: raw-html(raw)
      :format: html

   .. method:: step() -> None

      Force the torch optimizer to update the model parameters. For blocking updates, the optimizer updates the parameters immediately. For non-blocking updates, the optimizer updates the parameters during the next forward pass.

   .. method:: zero_grad() -> None

      Reset the gradients of the optimizer's parameters.
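
   The two update modes of `step` can be illustrated with a small toy model. This is a sketch under stated assumptions, not Heat's code: the class `ToyDataParallelOptimizer`, its scalar `param` and `grad` fields, and its `forward` method are hypothetical; only the flag names `blocking_parameter_updates` and `update_next` are borrowed from the attributes listed above, and no communication takes place.

   .. code-block:: python

      class ToyDataParallelOptimizer:
          """Hypothetical toy model of the documented blocking/non-blocking step() behavior."""

          def __init__(self, lr: float, blocking: bool = False) -> None:
              self.lr = lr
              self.blocking_parameter_updates = blocking
              self.update_next = False  # True while an update is pending
              self.param = 1.0  # a single scalar "model parameter"
              self.grad = 0.0

          def _apply_update(self) -> None:
              # Plain SGD update of the scalar parameter.
              self.param -= self.lr * self.grad

          def step(self) -> None:
              if self.blocking_parameter_updates:
                  # Blocking: the parameter is updated immediately.
                  self._apply_update()
              else:
                  # Non-blocking: only mark the update as pending; in a real
                  # setup, gradient communication could still be in flight here.
                  self.update_next = True

          def forward(self, x: float) -> float:
              # A pending update is applied before the parameter is used again.
              if self.update_next:
                  self._apply_update()
                  self.update_next = False
              return self.param * x


      eager = ToyDataParallelOptimizer(lr=0.5, blocking=True)
      eager.grad = 1.0
      eager.step()  # eager.param is now 0.5

      deferred = ToyDataParallelOptimizer(lr=0.5, blocking=False)
      deferred.grad = 1.0
      deferred.step()  # deferred.param is still 1.0
      deferred.forward(2.0)  # applies the pending update first, then returns 0.5 * 2.0 = 1.0

   The point of deferring the update is to leave room for communication: in the non-blocking mode, the parameters only need to be current when the next forward pass begins.
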