"""
MPI enabled data parallel optimizers
"""
import inspect
import math
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as tDDP
from typing import Union, List, Tuple, Dict
from ..core.communication import MPICommunication
from ..core.communication import MPI
from ..core.communication import MPI_WORLD
from .utils import DetectMetricPlateau
# Public API of this module: the two optimizer wrapper classes defined below.
__all__ = ["DataParallelOptimizer", "DASO"]
def __sum_f16_cb(buffer_a, buffer_b, _):
    """
    Custom MPI reduction callback that sums two ``torch.half`` buffers.

    Both raw buffers are viewed as half precision tensors, ``buffer_a`` is added
    onto ``buffer_b`` and the summed bytes are written back into ``buffer_b``.
    The third argument handed over by MPI is not used.
    """
    half = torch.half
    try:
        # torch >= 2.0.0 exposes untyped storages
        lhs = torch.HalfTensor().set_(
            torch.UntypedStorage.from_buffer(buffer_a, "native", dtype=half)
        )
        rhs = torch.HalfTensor().set_(
            torch.UntypedStorage.from_buffer(buffer_b, "native", dtype=half)
        )
    except AttributeError:
        # older torch releases: fall back to the typed storage class
        lhs = torch.HalfTensor().set_(torch.HalfStorage.from_buffer(buffer_a, "native"))
        rhs = torch.HalfTensor().set_(torch.HalfStorage.from_buffer(buffer_b, "native"))
    rhs += lhs
    # expose the summed tensor memory as an MPI buffer and copy it into the output
    total_bytes = rhs.element_size() * rhs.numel()
    buffer_b[:] = MPI.memory.fromaddress(rhs.data_ptr(), nbytes=total_bytes)
def __sum_bfloat_cb(buffer_a, buffer_b, _):
    """
    Custom MPI reduction callback that sums two ``torch.bfloat16`` buffers.

    Both raw buffers are viewed as bfloat16 tensors, ``buffer_a`` is added onto
    ``buffer_b`` and the summed bytes are written back into ``buffer_b``.
    The third argument handed over by MPI is not used.
    """
    bfloat = torch.bfloat16
    try:
        # torch >= 2.0.0 exposes untyped storages
        lhs = torch.BFloat16Tensor().set_(
            torch.UntypedStorage.from_buffer(buffer_a, "native", dtype=bfloat)
        )
        rhs = torch.BFloat16Tensor().set_(
            torch.UntypedStorage.from_buffer(buffer_b, "native", dtype=bfloat)
        )
    except AttributeError:
        # older torch releases: fall back to the typed storage class
        lhs = torch.BFloat16Tensor().set_(torch.BFloat16Storage.from_buffer(buffer_a, "native"))
        rhs = torch.BFloat16Tensor().set_(torch.BFloat16Storage.from_buffer(buffer_b, "native"))
    rhs += lhs
    # expose the summed tensor memory as an MPI buffer and copy it into the output
    total_bytes = int(rhs.numel()) * rhs.element_size()
    buffer_b[:] = MPI.memory.fromaddress(rhs.data_ptr(), nbytes=total_bytes)
# Register the two callbacks above as commutative MPI reduction operations.
# DASO.__init__ selects one of them as `cast_fn` depending on `downcast_type`.
mpi_sum_f16 = MPI.Op.Create(__sum_f16_cb, commute=True)
mpi_sum_bfloat = MPI.Op.Create(__sum_bfloat_cb, commute=True)
[docs]
class DASO:
r"""
Optimizer wrapper to use the Distributed Asynchronous and Selective Optimization (DASO) method.
This optimizer uses a local torch optimizer combined with the :func:`nn.DataParallelMultiGPU <heat.nn.data_parallel.DataParallelMultiGPU>`
to create local DPNNs on each node consisting of the GPUs on each node. Then those networks communicate
globally with MPI groups, each of which has a single GPU on each node.
DASO uses both local and global synchronization operations. Local synchronization operations are intended to be
done very frequently while global synchronizations are conducted asynchronously as the next batches are
computed.
This implementation requires that all nodes have the same number of GPUs.
There are four phases to training:
1. initialization: steps 1 to 8 below
2. Warmup phase: blocking averaging update occurs for global synchronization step
3. Cycling phase: for the global synchronization, the data is sent after a number of batches. the number of batches between synchronizations is referred to as `global_skips`. After the data is sent a number of batches pass before it is received (`batches_to_wait`). both of these cycle downward from `max_global_skips` for the global skips and 1/4th this value for `batches_to_wait`. When both values are equal to 1 and the loss is stable it will be reset to the initial values, then will decay again.
4. Cooldown phase: blocking averaging update occurs for global synchronization step
An example usage of this can be found in `heat/examples/nn/imagenet-DASO.py <https://github.com/helmholtz-analytics/heat/blob/504-docstring-formatting/examples/nn/imagenet-DASO.py>`_.
The recommended checklist for using this class is as follows:
1. initialize the local PyTorch process group and set the default device of the local GPUs.
2. define the torch network
3. define the `local_optimizer` -> a torch optimizer of your choice (tested with SGD)
4. optional, choose a learning rate scheduler. This is only for those learning rates which will also step the optimizer
5. initialize DASO with the local optimizers and parameters
6. initialize :func:`nn.DataParallelMultiGPU <heat.nn.data_parallel.DataParallelMultiGPU>` with the torch network and DASO
7. If using automatic mixed precision (:class:`torch.cuda.amp`), initialize the gradient scaler and add it to DASO (:func:`add_scaler`)
8. ensure that the DataLoaders evenly distribute the data between all the processes. This can be done by using the `torch.utils.data.distributed.DistributedSampler <https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler>`_ with the `num_replicas` and `rank` parameters
9. call `daso_optimizer.epoch_loss_logic(training_loss)` at the end of each epoch
10. set the number of batches per epoch (`daso_optimizer.last_batch = number_of_batches`)
11. ensure that the step function used in training is that of the DASO optimizer
Parameters
----------
local_optimizer: torch.optim.Optimizer
This optimizer handles the optimization of the local NN. Example: `torch.optim.SGD`. \n
This can be any optimizer, although tests were only completed with SGD. Other optimizers may show
unexpected behavior.
total_epochs: int
The total number of epochs for training. Needed to determine when to enter the cooldown phase.
comm: MPICommunication, optional
The MPI communicator to use for training. \n
Default: :func:`MPI_WORLD <heat.core.comm.MPI_WORLD>`
warmup_epochs: int, optional
The number of epochs to complete with a blocking averaging operation after each batch before entering
the cycling phase.\n
Default: 4
cooldown_epochs: int, optional
The number of epochs with blocking averaging operations after each batch at the end of training.\n
Default: 4
scheduler: torch.optim.lr_scheduler, optional
Local PyTorch learning rate scheduler. This must be used in the case that the scheduler's `step` function
is supposed to be called instead of the optimizer's `step` function.\n
Default: None
stability_level: float, optional
This can be viewed as the percent change threshold that the loss must exceed to be judged as improving.
When the loss is within this percent change for 2 epochs, then it is judged as stable.\n
Default: 0.05
max_global_skips: int, optional
The maximum number of batches between the beginning of a global synchronization process.\n
Default: 8
sending_chunk_size: int, optional
During the global synchronization step, the network parameters are split into chunks of data to overlap
communication and computation. This value is the maximum chunk size.\n
Default: 10,000,000
downcast_type: torch.dtype, optional
Options: [torch.bfloat16, torch.half, torch.float]
When the network parameters are sent during the global synchronization step, they are cast down to
a smaller dtype, by default this is `torch.bfloat16`. Smaller torch dtypes are not implemented.\n
Default: torch.bfloat16
use_mpi_groups: bool, optional
Use MPI groups to divide the global communicator. If True, use MPI GROUPs, otherwise, use MPI SPLIT.\n
Default: True
skip_reduction_factor: int, optional
How much to reduce the global/local skips by when the loss has stabilized.\n
Default: 2
local_skip_factor: int, optional
How many local skips occur per global skip, i.e. number of local skips = global_skips // local_skip_factor.\n
Default: 4
verbose: bool, optional
If true, print out a collection of debug messages.\n
Default: False
"""
def __init__(
    self,
    local_optimizer: torch.optim.Optimizer,
    total_epochs: int,
    comm: MPICommunication = MPI_WORLD,
    warmup_epochs: int = 4,
    cooldown_epochs: int = 4,
    scheduler: torch.optim.lr_scheduler = None,
    stability_level: float = 0.05,
    max_global_skips: int = 8,
    sending_chunk_size: int = 10_000_000,
    downcast_type: torch.dtype = torch.bfloat16,
    use_mpi_groups: bool = True,
    skip_reduction_factor: int = 2,
    local_skip_factor: int = 4,
    verbose: bool = False,
):  # noqa: D107
    # validate all arguments first; the locals of this frame are exactly the init arguments
    frame = inspect.currentframe()
    init_args = inspect.getargvalues(frame)[3]
    self.__init_checktypes(init_args)

    # pick the MPI reduction operation matching the dtype used for sending
    self.cast_dtype = downcast_type
    if downcast_type == torch.bfloat16:
        self.cast_fn = mpi_sum_bfloat
    elif downcast_type == torch.half:
        self.cast_fn = mpi_sum_f16
    else:
        self.cast_fn = MPI.SUM

    self.comm = comm
    self.verbose = verbose
    self.local_optimizer = local_optimizer
    # reference of optimizer's params
    self.params_ref = local_optimizer.param_groups[0]["params"]
    self.scheduler = scheduler

    rank = self.comm.rank
    loc_gpus = torch.cuda.device_count()
    # this assumes that there are an equal number of GPUs per node,
    # if a change is desired a comm here to find the lowest number would work for this, however
    # a change would also need to be made in heat.nn.DataParallelMultiGPU
    self.loc_gpus = loc_gpus
    local_rank = rank % loc_gpus
    if loc_gpus > 1:
        # one reduced communicator per local GPU index: group i holds rank i of every node
        base_loc_ranks = list(range(0, self.comm.size, loc_gpus))
        reduced_comms, reduced_ranks = [], []
        for i in range(loc_gpus):
            lp_ranks = [j + i for j in base_loc_ranks]
            if use_mpi_groups:
                new_group = self.comm.group.Incl(lp_ranks)
                new_comm = self.comm.Create_group(new_group)
                reduced_comms.append(MPICommunication(new_comm))
            else:
                color = 111 + i if rank in lp_ranks else 222 + i
                key = 0 + i if rank in lp_ranks else 444 + i
                reduced_comms.append(MPICommunication(self.comm.Split(color, key)))
            reduced_ranks.append(tuple(lp_ranks))
        self.reduced_comms, self.reduced_ranks = reduced_comms, reduced_ranks
        self.base_loc_ranks = base_loc_ranks

    self.device = f"cuda:{str(local_rank)}"
    torch.cuda.set_device(device=self.device)

    # batch / epoch bookkeeping; `last_batch` has to be set by the user before `step` is called
    self.current_batch, self.last_batch = 0, None
    self._prev_params = []
    self.epoch = 0
    # index of the group sending in the current / previous global sync
    self._send_mod, self._send_mod_m1 = 0, None

    # all skips start at 0, i.e. blocking synchronization (warmup phase)
    self.global_skip = 0
    self.local_skip = 0
    self.batches_to_wait = 0
    self.max_gs = max_global_skips

    self.warmup_epochs = warmup_epochs
    self.cooldown_epochs = cooldown_epochs
    self.total_epochs = total_epochs

    # used in the sending of the params
    self._param_send_buffer_size = None
    self.param_dict, self.shapes = None, None
    self._param_send_shp = None
    self.split = None

    self.skip_reduction_factor = skip_reduction_factor
    # the local_skip_factor is the factor by which the global skips are divided by initially
    # and upon reset
    self.local_skip_factor = local_skip_factor

    self.stability = DetectMetricPlateau(patience=2, threshold=stability_level)

    self._gs8_waits = 3
    self._gs8_waited = 0

    self.split_val = sending_chunk_size

    # TODO: its possible that the split indexes could be used to avoid the concatenating method used currently
    self.split_inds = None
    self.amp = False

    self.print0("Finished DASO init")
[docs]
def add_scaler(self, scaler: torch.cuda.amp.GradScaler) -> None:
"""
Create a reference to torch's `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ used in torch's automatic mixed
precision.
Parameters
----------
scaler: torch.cuda.amp.GradScaler
the gradient scaler to be used
"""
self.scaler = scaler
self.amp = True
@staticmethod
def __init_checktypes(args: Dict) -> None:
    """
    Validate the arguments of ``__init__``.

    Parameters
    ----------
    args: Dict
        mapping of the init parameter names to the values which were passed

    Raises
    ------
    TypeError
        if an argument has the wrong type
    ValueError
        if an argument has a value outside of its valid range
    """
    if not isinstance(args["local_optimizer"], torch.optim.Optimizer):
        raise TypeError(
            f"Local optimizer must be a torch optimizer object, currently {type(args['local_optimizer'])}"
        )
    if not isinstance(args["comm"], MPICommunication):
        raise TypeError(
            f"Comm object must be a ht.MPICommunication object, currently {type(args['comm'])}"
        )
    if not isinstance(args["total_epochs"], int):
        raise TypeError(f"total_epochs must be an int, currently {type(args['total_epochs'])}")
    if not isinstance(args["warmup_epochs"], int):
        raise TypeError(
            f"warmup_epochs must be an int, currently {type(args['warmup_epochs'])}"
        )
    if not isinstance(args["cooldown_epochs"], int):
        raise TypeError(
            f"cooldown_epochs must be an int, currently {type(args['cooldown_epochs'])}"
        )

    scheduler = args["scheduler"]
    if scheduler is not None:
        # `LRScheduler` is the public base class of newer torch releases, `_LRScheduler` the
        # legacy one; accept whichever of the two exists in the installed torch version
        lr_module = torch.optim.lr_scheduler
        candidates = (
            getattr(lr_module, "LRScheduler", None),
            getattr(lr_module, "_LRScheduler", None),
        )
        bases = tuple(base for base in candidates if base is not None)
        # `step` calls `scheduler.step()`, i.e. an instance is expected. Scheduler classes
        # are still let through for backward compatibility with the previous check.
        is_instance = isinstance(scheduler, bases)
        is_class = isinstance(scheduler, type) and issubclass(scheduler, bases)
        if not (is_instance or is_class):
            raise TypeError(
                f"scheduler must be a torch learning rate scheduler, currently {args['scheduler']}"
            )

    if not isinstance(args["stability_level"], float):
        raise TypeError(
            f"stability_level must be a float, currently {type(args['stability_level'])}"
        )
    if not isinstance(args["max_global_skips"], int):
        raise TypeError(
            f"max_global_skips must be an int, currently {type(args['max_global_skips'])}"
        )
    if not isinstance(args["sending_chunk_size"], int):
        raise TypeError(
            f"sending_chunk_size must be an int, currently {type(args['sending_chunk_size'])}"
        )
    if not isinstance(args["verbose"], bool):
        raise TypeError(f"verbose must be a bool, currently {type(args['verbose'])}")
    if not isinstance(args["use_mpi_groups"], bool):
        raise TypeError(
            f"`use_mpi_groups` must be a bool, currently {type(args['use_mpi_groups'])}"
        )
    if not isinstance(args["downcast_type"], torch.dtype):
        raise TypeError(
            f"downcast_type must be a torch.dtype, currently {args['downcast_type']}"
        )
    if args["downcast_type"] not in [torch.bfloat16, torch.half, torch.float]:
        raise ValueError(
            f"downcast_type must be one of [torch.bfloat16, torch.half, torch.float], "
            f"currently {args['downcast_type']}"
        )
    if not isinstance(args["skip_reduction_factor"], int):
        raise TypeError(
            f"skip_reduction_factor must be an integer, currently {type(args['skip_reduction_factor'])}"
        )
    if not isinstance(args["local_skip_factor"], int):
        raise TypeError(
            f"local_skip_factor must be an integer, currently {type(args['local_skip_factor'])}"
        )

    # value range checks
    if args["warmup_epochs"] < 0:
        raise ValueError(f"warmup_epochs must be >= 0, currently {args['warmup_epochs']}")
    if args["cooldown_epochs"] < 0:
        raise ValueError(f"cooldown_epochs must be >= 0, currently {args['cooldown_epochs']}")
    if args["max_global_skips"] < 0:
        raise ValueError(
            f"max_global_skips must be >= 0, currently {args['max_global_skips']}"
        )
    if args["sending_chunk_size"] <= 0:
        raise ValueError(
            f"sending_chunk_size must be > 0, currently {args['sending_chunk_size']}"
        )
    if args["total_epochs"] <= 0:
        raise ValueError(f"total_epochs must be > 0, currently {args['total_epochs']}")
    if args["skip_reduction_factor"] <= 0:
        raise ValueError(
            f"skip_reduction_factor must be > 0, currently {args['skip_reduction_factor']}"
        )
    if args["local_skip_factor"] <= 0:
        raise ValueError(
            f"local_skip_factor must be > 0, currently {args['local_skip_factor']}"
        )
[docs]
@torch.no_grad()
def epoch_loss_logic(
self, loss: Union[torch.Tensor, int, float], loss_globally_averaged: bool = False
) -> None:
"""
Function controlling the number of batches between global synchronizations and the batches to wait before
receiving the sent parameters. The warm-up and cool-down phases are also controlled here.
This function should be called at the end of each epoch with the training loss value at the end of the epoch.
The number of batches between local synchronizations can also be modified here with minor code adjustments.
Parameters
----------
loss: torch.Tensor or float
loss value of the current epoch
loss_globally_averaged: bool, optional
boolean if the loss is already globally averaged
"""
if not loss_globally_averaged:
loss_send = torch.zeros(self.comm.size)
# loss.data -> this will get the raw number from the lass value and nothing else
loss_send[self.comm.rank] = loss.data if isinstance(loss, torch.Tensor) else loss
self.comm.Allreduce(MPI.IN_PLACE, loss_send, MPI.SUM)
avg_loss = torch.mean(loss_send)
else:
avg_loss = torch.tensor(loss)
if self.epoch < self.warmup_epochs:
self.global_skip = 0
self.local_skip = 0
self.batches_to_wait = 0
self.print0(
f"Warmup Phase, parameters of next epoch\n\t Global Skips: {self.global_skip}, "
f" Local Skips {self.local_skip}, Batches to wait: {self.batches_to_wait}"
)
return
elif self.warmup_epochs == self.epoch:
self.global_skip = 4
self.local_skip = 1
self.batches_to_wait = 1
self.print0(
f"End of Warmup Phase, parameters of next epoch\n\t Global Skips: {self.global_skip}, "
f" Local Skips {self.local_skip}, Batches to wait: {self.batches_to_wait}"
)
if self.epoch >= self.total_epochs - self.cooldown_epochs:
self.global_skip = 0
self.local_skip = 0
self.batches_to_wait = 0
self.print0(
f"Cooldown Phase, parameters of next epoch:\n\tGlobal Skips: {self.global_skip}, "
f" Local Skips {self.local_skip:.4f}, Batches to wait: {self.batches_to_wait}"
)
return
if self.global_skip == self.max_gs and self.max_gs > 4:
self._gs8_waited += 1
self.print0(
f"Best loss value: {self.stability.best * (1.0 - self.stability.threshold):.4f}"
f" Current loss: {avg_loss:.4f}, Worse epochs: {self.stability.num_bad_epochs}"
)
stable = self.stability.test_if_improving(avg_loss)
if stable and self.global_skip > 1:
# drop gs by factor of 2
self.global_skip //= self.skip_reduction_factor
self.local_skip //= self.skip_reduction_factor
self.batches_to_wait -= 1 # old was //= 2
self.print0("dropping skips")
if self.global_skip > 0:
if self.batches_to_wait == 0:
self.batches_to_wait = 1
if self.local_skip == 0:
self.local_skip = 1
self._gs8_waited = 0
elif self.global_skip == 1 and stable:
self.global_skip = self.max_gs
self.local_skip = self.max_gs // self.local_skip_factor
self.batches_to_wait = self.max_gs // self.local_skip_factor
self._gs8_waited = 0
self.print0(
f"\tNext Parameters: Global Skips: {self.global_skip}, Local Skips {self.local_skip}, "
f" Batches to wait: {self.batches_to_wait}, \n\tCurrent loss: {avg_loss:.4f}, "
f" Worse epochs: {self.stability.num_bad_epochs}"
)
[docs]
@torch.no_grad()
def _global_sync(self, batches_to_wait: int) -> None:
"""
Performs a global synchronization. If `batches_to_wait > 0` this will wait for that many
batches before received in the parameters.
Full syncs are only performed on a single MPI group
"""
current_comm = self.reduced_comms[self._send_mod]
current_ranks = self.reduced_ranks[self._send_mod]
if self.comm.rank in current_ranks:
self._gs_send_params(current_comm, batches_to_wait)
if self.batches_to_wait != 0:
# update parameters from the last sending (if there)
self._gs_rcv_update_params() # -> splits off irrelevant ranks
# needs to happen on all ranks:
self._local_update(self._send_mod_m1)
if self.current_batch == self.last_batch or self.batches_to_wait == 0:
# receive the sent data to sync params across all ranks
if self.comm.rank in current_ranks:
self._gs_rcv_update_params_last_batch(current_ranks)
else:
if len(self._prev_params) > 0:
raise ValueError(
f"DEBUG: OFF RANKS! len(prev_params) > 0! {len(self._prev_params)}"
f" batch number {self.current_batch}"
)
self._local_update(self._send_mod)
self._send_mod_m1 = None
if self.current_batch == self.last_batch:
self._send_mod = 0
self.epoch += 1
self.current_batch = 0
else:
self.current_batch += 1
self._send_mod = self._send_mod + 1 if self._send_mod <= self.loc_gpus - 2 else 0
else:
self.current_batch += 1
self._send_mod_m1 = self._send_mod
self._send_mod = self._send_mod + 1 if self._send_mod <= self.loc_gpus - 2 else 0
[docs]
def _gs_create_param_dict(self) -> Tuple[Dict, Dict]:
"""
Create the shape and param dictionary used for sending parameters around the MPI world.
this will also define the buffer size if it was not previously defined.
"""
if self.shapes is not None:
return self.param_dict, self.shapes
param_dict = {}
shapes = {}
st = 0
for name, param in self.module.named_parameters():
param_dict[name] = param
numel = param.numel()
shapes[name] = [param.shape, slice(st, st + numel), param.dtype]
st += numel
if self._param_send_buffer_size is None:
# use the total number of elements to define the sending buffer shape (single int)
self._param_send_buffer_size = st
self.param_dict = param_dict
self.shapes = shapes
return param_dict, shapes
[docs]
@torch.no_grad()
def _gs_rcv_update_params(self) -> None:
"""
Receive the previously sent parameters for the last sending MPI group.
this is also where the sent and local parameters are merged together.
"""
# wait for the global sync data and update on the selected rank,
if self._send_mod_m1 is None:
return
prev_ranks = self.reduced_ranks[self._send_mod_m1]
if self.comm.rank not in prev_ranks or len(self._prev_params) == 0:
# if no old gradients, return without doing anything
return
prev_params = self._prev_params.pop(0)
batches_between = float(prev_params[3])
shapes = prev_params[2]
# add the weighted average to the received params
numer = batches_between * 2.0 if batches_between > 0.0 else 1.0
denom = float(len(prev_ranks) + numer)
factor = numer / denom
if not self.split:
# only a single buffer
prev_params[0].Wait()
rcv_params = prev_params[1] / denom
for name, param in self.module.named_parameters():
if param.requires_grad:
update = (
rcv_params[shapes[name][1]].reshape(shapes[name][0]).to(shapes[name][2])
)
# NOTE: update here is the sum of the params across the processes
param *= factor
param += update
else:
# receive the first buffer
ind = 0
prev_params[0][0].Wait()
del prev_params[0][0]
rcv_params = prev_params[1][ind] / denom
# jit the parameter setting
for name, param in self.module.named_parameters():
if param.requires_grad:
if shapes[name][1].stop > len(rcv_params):
# when the end of the slice is higher than the amount received, wait for the next buffer
ind += 1
prev_params[0][0].Wait()
del prev_params[0][0]
new_rcv_params = prev_params[1][ind] / denom
rcv_params = torch.cat((rcv_params, new_rcv_params))
update = (
rcv_params[shapes[name][1]].reshape(shapes[name][0]).to(shapes[name][2])
)
# NOTE: update here is the sum of the params across the processes
param *= factor
param += update
[docs]
@torch.no_grad()
def _gs_rcv_update_params_last_batch(self, current_ranks: Tuple) -> None:
"""
Abstracted receive for the last batch (and if `global_skips` == 0)
"""
if len(self._prev_params) > 1:
raise ValueError(f"length of previous params > 1! {len(self._prev_params)}")
prev_params = self._prev_params.pop(0)
shapes = prev_params[2]
if not self.split:
prev_params[0].Wait()
rcv_params = prev_params[1] / float(len(current_ranks))
for name, param in self.module.named_parameters():
if param.requires_grad:
param[:] = (
rcv_params[shapes[name][1]].reshape(shapes[name][0]).to(shapes[name][2])
)
else:
ind1 = 0
prev_params[0][0].Wait()
del prev_params[0][0]
rcv_params = prev_params[1][ind1] / float(len(current_ranks))
for name, param in self.module.named_parameters():
if param.requires_grad:
while shapes[name][1].stop > len(rcv_params):
ind1 += 1
prev_params[0][0].Wait()
del prev_params[0][0]
new_rcv_params = prev_params[1][ind1] / float(len(current_ranks))
rcv_params = torch.cat((rcv_params, new_rcv_params))
param[:] = (
rcv_params[shapes[name][1]].reshape(shapes[name][0]).to(shapes[name][2])
)
[docs]
@torch.no_grad()
def _gs_send_params(self, current_comm: MPICommunication, batches_to_wait: int) -> None:
    """
    Pack and send the data required for a global synchronization on the `current_comm` group.
    `batches_to_wait` is stored with the parameters to keep track of it between sending and receiving.

    The parameters are packed into one flat buffer and summed with a non-blocking allreduce. If the
    buffer holds more than `split_val` elements it is sent in chunks. The pending request(s) are
    appended to `_prev_params` together with the buffer(s) and the shape dictionary.

    Raises
    ------
    ValueError
        if the packed parameters contain NaNs
    """
    op = MPI.SUM
    cast = False
    cast_int = 2  # 0: bfloat16, 1: half, 2: float (no cast)
    if self.global_skip < 1:
        cast = True
        op = self.cast_fn
        if self.cast_dtype == torch.bfloat16:
            cast_int = 0
        elif self.cast_dtype == torch.half:
            cast_int = 1
        # else: keep as floats (default case, see above)

    param_dict, shapes = self._gs_create_param_dict()
    sndparams = torch.zeros(
        self._param_send_buffer_size,
        device=self.device,
        dtype=self.cast_dtype if cast else None,
    )
    sndparams = self.__pack_data(sndparams, param_dict, cast_int)

    try:
        nans = sndparams.isnan().sum()
    except RuntimeError:
        # the isnan function isnt implemented in some cuda / torch implementations
        nans = sndparams.to(torch.half).isnan().sum()
    if nans:
        # sending NaNs would spread them to every process of the group
        raise ValueError(
            f"{int(nans)} NaNs found in the network parameters, cannot synchronize them."
        )

    if not self.split and sndparams.numel() <= self.split_val:
        # small enough to be sent as a single buffer
        new_wait = current_comm.Iallreduce(MPI.IN_PLACE, sndparams, op)
        self._prev_params.append([new_wait, sndparams, shapes, batches_to_wait])
        return

    # chunked sending; once chosen, this mode is kept for all following sends
    self.split = True
    num_splits = math.ceil(len(sndparams) / self.split_val)
    rem = len(sndparams) - (self.split_val * (num_splits - 1))
    # first one will be smaller than the rest (the remainder is first)
    splits = [rem] + [self.split_val] * (num_splits - 1)
    self.split_inds = splits

    params_list = [None] * num_splits
    waits = [None] * num_splits
    prev = 0
    for s in range(num_splits):
        # slice the params at the split points and start the reduction of this chunk
        params_list[s] = sndparams[prev : splits[s] + prev]
        prev += splits[s]
        waits[s] = current_comm.Iallreduce(MPI.IN_PLACE, params_list[s], op)
    self._prev_params.append([waits, params_list, shapes, batches_to_wait])
[docs]
@torch.no_grad()
def _local_update(self, sending_process: Tuple) -> None:
# use torch to send the network parameters of a single process to the other processes
if not torch.distributed.is_initialized() or sending_process is None:
return
snds = {
name: torch.distributed.broadcast(param, sending_process, async_op=True)
for name, param in self.module.named_parameters()
if param.requires_grad
}
for name, param in self.module.named_parameters():
if param.requires_grad:
snds[name].wait()
@staticmethod
@torch.no_grad()
@torch.jit.script
def __pack_data(
jtparams: torch.Tensor, iter_dict: Dict[str, torch.Tensor], cast: int
) -> torch.Tensor:
"""
Jitted loop to pack the data into a flattened buffer to be sent
"""
st = 0
cast_type = torch.float if cast == 2 else torch.bfloat16 if cast == 0 else torch.half
for par in iter_dict.values():
if par.requires_grad:
# flatten and prep the data for sending
p = torch.flatten(par)
p = p.to(cast_type)
jtparams[st : st + par.numel()] = p
st += par.numel()
return jtparams
[docs]
def print0(self, *args, **kwargs) -> None:
"""
Print a message on rank 0 if the class parameter `verbose` is set.
"""
if self.comm.rank == 0 and self.verbose:
print(*args, **kwargs)
[docs]
def reset(self) -> None:
"""
Reset the optimizer to its base state
"""
self.stability.reset()
self.global_skip = 0
self.local_skip = 0
self.batches_to_wait = 0
self.current_batch = 0
self._prev_params = []
self.epoch = 0
self._gs8_waited = 0
self.zero_grad()
[docs]
def set_model(self, model: torch.nn.Module) -> None:
"""
Set the local model for the optimizer.
This should be called during the init of :func:`nn.DataParallelMultiGPU <heat.nn.data_parallel.DataParallelMultiGPU>`.
However, this can also be called manually.
Parameters
----------
model: torch.nn.Module
the local torch model.
"""
self.module = model
[docs]
def _start_local_sync(self) -> None:
"""
*Start* local synchronizations for the next batches
"""
if not isinstance(self.module, tDDP) or self.module.require_backward_grad_sync:
# this has no effect if the module is not locally distributed in torch
return
self.module.require_backward_grad_sync = True
[docs]
def step(self) -> None:
"""
Perform a single optimization step.
This will perform the `step` operations of the local optimizer,
local learning rate scheduler (if defined), and the gradient scaler used in automatic mixed
precision (if defined).
Also in the step is the logic used for when to send and receive the global/local synchronizations.
Global Syncs occur on batches for which the modulus of the batch number and the `global_skip` number is 0.
If `batches_to_wait` > 0, the next batches have only local syncs. After that number of batches,
the data during the global sync phase is received.
Local synchronization can also be turned off if desired by increasing `local_skips` above 1.
Notes
-----
self.last_batch must be set!
"""
if self.last_batch is None:
raise ValueError(
"self.last_batch must be set as the number of batches (len(dataloader))"
)
if self.amp:
self.scaler.step(self.local_optimizer)
# todo: add something to tell if the grads have infs or nans
# Updates the scale for next iteration.
self.scaler.update()
elif self.scheduler is None:
self.local_optimizer.step()
else:
self.scheduler.step()
batch = self.current_batch
# knowing next_batch is important to make sure that the local sync is on
# or if the next is the last batch
next_batch = batch + 1
gs = self.global_skip
ls = self.local_skip
# determine if to do the syncs
gmod = batch % gs if gs > 0 else 0
lmod = batch % ls if ls > 0 else 0
batches_to_wait = self.batches_to_wait
# ensure that the batch that will receive will be before the end of the training loop
btw = (
batches_to_wait
if batches_to_wait + batch <= self.last_batch
else self.last_batch - batch
)
# do full sync on global skips and on the last batch
if batch == self.last_batch or gmod == 0:
return self._global_sync(btw)
if next_batch % gs == 0:
self._start_local_sync()
self.current_batch += 1
return
if gmod < btw:
# do nothing on these batches (maintain the local sync)
self.current_batch += 1
if next_batch == self.last_batch:
self._start_local_sync()
return
elif gmod == btw:
# local updates should be on before this is called!
self._gs_rcv_update_params()
self._local_update(self._send_mod_m1)
if ls > 1:
self._stop_local_sync()
if ls == 1 and next_batch != self.last_batch:
self.current_batch += 1
self._start_local_sync()
return
if lmod == 0:
self._stop_local_sync()
elif next_batch % ls == 0:
self._start_local_sync()
if next_batch == self.last_batch:
self._start_local_sync()
self.current_batch += 1
[docs]
def _stop_local_sync(self) -> None:
"""
Stop local synchronizations for the next batches
"""
if not isinstance(self.module, tDDP) or not self.module.require_backward_grad_sync:
# this has no effect if the module is not locally distributed in torch
return
self.module.require_backward_grad_sync = False
[docs]
def zero_grad(self) -> None:
"""
Reset gradients of local optimizer's parameters.
"""
# reset view onto params in order to reset all gradients
self.local_optimizer.param_groups[0]["params"] = self.params_ref[:]
self.local_optimizer.zero_grad(set_to_none=False)
[docs]
class DataParallelOptimizer:
    """
    Wrapper around a torch.optim.Optimizer for data parallelism, to be used together with the DataParallel (DP) class.
    To optimize a DP module, the DP optimizer has to be passed to the DP module during its initialization.
    See :func:`nn.DataParallel <heat.nn.data_parallel.DataParallel>` for a basic example of usage.

    Attributes
    ----------
    torch_optimizer : torch.optim.Optimizer
        the wrapped Torch optimizer
    blocking : bool
        use blocking communications or not. will typically be overwritten by :func:`nn.DataParallel <heat.nn.data_parallel.DataParallel>`
    """

    def __init__(self, torch_optimizer: torch.optim.Optimizer, blocking: bool = False):  # noqa: D107
        if not isinstance(blocking, bool):
            raise TypeError(f"blocking parameter must be a boolean, currently {type(blocking)}")
        self.torch_optimizer = torch_optimizer
        # True: the communication during the parameter updates is blocking
        self.blocking_parameter_updates = blocking
        # only relevant for non-blocking: marks that a step is due during the next iteration
        self.update_next = False
        # reference to the parameter list of the optimizer's first param group
        self.params_ref = torch_optimizer.param_groups[0]["params"]

    def step(self) -> None:
        """
        Request an update of the model parameters. In blocking mode the torch optimizer steps right away.
        In non-blocking mode only `update_next` is set, so that the update happens during the next forward.
        """
        if not self.blocking_parameter_updates:
            self.update_next = True
            return
        self.torch_optimizer.step()

    def zero_grad(self) -> None:
        """
        Zero the gradients of the optimizer's parameters.
        """
        # hand a fresh copy of the full parameter list to the first param group,
        # so that the gradients of all of them are reset
        first_group = self.torch_optimizer.param_groups[0]
        first_group["params"] = self.params_ref[:]
        self.torch_optimizer.zero_grad(set_to_none=False)