mirror of https://github.com/hpcaitech/ColossalAI
592 lines
25 KiB
Python
592 lines
25 KiB
Python
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
|
from functools import partial
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.optim import Optimizer
|
|
|
|
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
|
from colossalai.context import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
from colossalai.tensor import ColoParameter, ProcessGroup
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from ._utils import (
|
|
calculate_global_norm_from_list,
|
|
compute_norm,
|
|
flatten,
|
|
has_inf_or_nan,
|
|
reduce_tensor_dp_group,
|
|
release_param_grad,
|
|
split_half_float_double,
|
|
sync_param,
|
|
)
|
|
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
|
|
|
|
|
|
class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|
"""Optimizer used for ZeRO-1 and ZeRO-2.
|
|
"""
|
|
|
|
def __init__(
        self,
        optimizer: Optimizer,
        initial_scale: int = 2**16,    # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.,
        backoff_factor: float = .5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,    # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,    # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = False,
        partition_grad: bool = False,    # stage 2 flag
        cpu_offload: bool = False,    # cpu offload
        forced_dtype: Optional[torch.dtype] = None):
    """Wrap ``optimizer`` and shard its parameters over the data parallel group.

    Args:
        optimizer: the optimizer to wrap; its ``param_groups`` are rewritten so
            that each group holds a single flat fp32 tensor owned by this rank.
        initial_scale, min_scale, growth_factor, backoff_factor,
        growth_interval, hysteresis, max_scale: forwarded unchanged to
            ``DynamicGradScaler``.
        clip_grad_norm: stored for ``_unscale_and_clip_grads``; clipping is
            only applied there when the value is greater than zero.
        verbose: forwarded to the grad scaler and used to log the partition
            sizes in ``_partition_param_list``.
        reduce_bucket_size: element-count threshold checked in
            ``_add_to_reduction_bucket`` before a reduction is triggered.
        communication_dtype: passed as ``dtype`` to ``reduce_tensor_dp_group``.
        overlap_communication: create a dedicated CUDA stream and attach
            reduction hooks to the parameters.
        partition_grad: stage 2 flag; reduction hooks are attached and each
            gradient is reduced to the rank that owns its parameter.
        cpu_offload: keep the flat fp32 copy of the owned parameters on CPU.
        forced_dtype: when truthy, every parameter is cast to this dtype first.

    Raises:
        NotImplementedError: if the process group found on the parameters is
            neither ``None`` nor a ``ProcessGroup``.
    """
    # TODO: add support for
    # 1. fp16 master weights
    # 2. contiguous gradients
    # 3. cpu offload
    # 4. support when some parameters requires_grad = False
    super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
    self._dtype = self.optim.param_groups[0]['params'][0].dtype
    self._logger = get_dist_logger()
    self._verbose = verbose

    # stage 2 / offload switches
    self._partition_grads = partition_grad
    self._cpu_offload = cpu_offload

    # resolve the data parallel / model parallel groups either from the
    # ColoParameter process group or from the global context
    colo_pg = self._search_colo_process_group()
    if colo_pg is not None and not isinstance(colo_pg, ProcessGroup):
        raise NotImplementedError

    if colo_pg is None:
        dp_parallel_mode = ParallelMode.DATA
        mp_parallel_mode = ParallelMode.MODEL

        self._dp_parallel_mode = dp_parallel_mode
        self._mp_parallel_mode = mp_parallel_mode
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)
        self._world_size = gpc.get_world_size(dp_parallel_mode)
        self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
        self._dp_torch_group = gpc.get_group(dp_parallel_mode)
        self._mp_torch_group = None
        if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
            self._mp_torch_group = gpc.get_group(mp_parallel_mode)
    else:
        self._local_rank = colo_pg.dp_local_rank()
        self._world_size = colo_pg.dp_world_size()
        self._dp_global_ranks = colo_pg.get_ranks_in_dp()
        self._dp_torch_group = colo_pg.dp_process_group()
        self._mp_torch_group = None
        if colo_pg.tp_world_size() > 1:
            self._mp_torch_group = colo_pg.tp_process_group()

    # bookkeeping for mixed precision: original params per group and the
    # flat fp32 copy of the shard owned by this rank
    self._fp16_param_groups = dict()
    self._fp32_flat_param_groups_of_current_rank = dict()

    # communication settings
    self._overlap_communication = overlap_communication
    self._reduce_bucket_size = reduce_bucket_size
    self._communication_dtype = communication_dtype

    # dynamic loss scaling
    self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                         min_scale=min_scale,
                                         growth_factor=growth_factor,
                                         backoff_factor=backoff_factor,
                                         growth_interval=growth_interval,
                                         hysteresis=hysteresis,
                                         max_scale=max_scale,
                                         verbose=verbose)
    self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

    # gradient clipping threshold
    self._clip_grad_norm = clip_grad_norm

    if forced_dtype:
        # cast every parameter in place and remember the new working dtype
        for group in self.optim.param_groups:
            for param in group['params']:
                param.data = param.data.to(forced_dtype)
        self._dtype = forced_dtype

    # make sure the arguments / parameters are consistent
    self._sanity_checks()

    # the stores hold the buffers used for zero; the tensors used by
    # mixed precision training are tracked on this object instead
    self._param_store = ParameterStore(self._dp_torch_group)
    self._grad_store = GradientStore(self._dp_torch_group)
    self._bucket_store = BucketStore(self._dp_torch_group)

    # walk over the optimizer's param groups, split each one across the data
    # parallel ranks and register the resulting buffers in the parameter store
    for group_id, param_group in enumerate(self.optim.param_groups):
        trainable_params = [p for p in param_group['params'] if p.requires_grad]

        # keep the original (fp16) params for bookkeeping
        self._fp16_param_groups[group_id] = trainable_params

        # record which rank owns which parameter; every parameter is
        # assigned to exactly one rank by `_partition_param_list`
        for rank, rank_params in enumerate(self._partition_param_list(trainable_params)):
            self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, rank_params)
            for p in rank_params:
                self._param_store.set_param_to_rank(p, rank)

        # move to cpu to make room to create the flat tensor
        for p in trainable_params:
            p.data = p.data.cpu()

        # build one flat CUDA tensor per rank from the reordered params
        for rank in range(self._world_size):
            rank_params = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
            with torch.no_grad():
                flat_tensor = flatten(rank_params)
            flat_tensor = flat_tensor.data.cuda()
            self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)

        # sync the individual params with their flat tensor
        for rank in range(self._world_size):
            sync_param(flat_tensor=self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id),
                       tensor_list=self._param_store.get_fp16_params_by_rank_group(rank, group_id))

        # fp32 copy of the shard this rank is responsible for
        owned_flat = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
        master_flat = owned_flat.float()
        device = 'cpu' if self._cpu_offload else get_current_device()
        master_flat = master_flat.to(device)
        master_flat.requires_grad = True
        self._fp32_flat_param_groups_of_current_rank[group_id] = master_flat

        # the wrapped optimizer must only see the tensor managed by this
        # data parallel rank, so that step() updates nothing else
        param_group['params'] = [master_flat]

        # nothing has been reduced yet
        for p in trainable_params:
            self._param_store.set_param_reduction_state(p, False)

    # initialize the communication stream used for
    # communication-computation overlapping
    if self._overlap_communication:
        self._comm_stream = torch.cuda.Stream()

    # reduction hooks are only needed when overlapping communication or
    # when stage 2 is used; stage 1 without overlapping attaches no hook
    if self._overlap_communication or self._partition_grads:
        self._attach_reduction_hook()
|
|
|
|
@property
def dtype(self):
    """Working dtype of the managed parameters (``forced_dtype`` if one was given)."""
    working_dtype = self._dtype
    return working_dtype
|
|
|
|
@property
def loss_scale(self):
    """Current loss scale as reported by ``self.grad_scaler``."""
    scaler = self.grad_scaler
    return scaler.scale
|
|
|
|
@property
def num_param_groups(self):
    """Number of parameter groups recorded in ``_fp16_param_groups``."""
    groups = self._fp16_param_groups
    return len(groups)
|
|
|
|
def _sanity_checks(self):
|
|
assert torch.cuda.is_available(), 'CUDA is required'
|
|
for param_group in self.optim.param_groups:
|
|
group_params = param_group['params']
|
|
for param in group_params:
|
|
assert param.dtype == self._dtype, \
|
|
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
|
|
|
|
def _search_colo_process_group(self):
|
|
colo_flag = False
|
|
colo_pg = None
|
|
for param_group in self.optim.param_groups:
|
|
group_params = param_group['params']
|
|
for param in group_params:
|
|
if isinstance(param, ColoParameter):
|
|
colo_flag = True
|
|
if colo_pg is None:
|
|
colo_pg = param.get_process_group()
|
|
else:
|
|
assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
|
|
elif colo_flag:
|
|
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
|
|
return colo_pg
|
|
|
|
def _partition_param_list(self, param_list):
|
|
params_per_rank = [[] for _ in range(self._world_size)]
|
|
numel_per_rank = [0 for _ in range(self._world_size)]
|
|
|
|
# partititon the parameters in a greedy fashion
|
|
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
|
for param in sorted_params:
|
|
# allocate this parameter to the rank with
|
|
# the smallest numel for load balancing purpose
|
|
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
|
params_per_rank[rank_to_go].append(param)
|
|
numel_per_rank[rank_to_go] += param.numel()
|
|
|
|
if self._verbose:
|
|
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
|
|
return params_per_rank
|
|
|
|
###########################
|
|
# Backward Reduction Hook #
|
|
###########################
|
|
|
|
def _grad_handler(self, param, grad, reduce_rank):
|
|
self._add_to_reduction_bucket(param, reduce_rank)
|
|
return grad
|
|
|
|
def _attach_reduction_hook(self):
|
|
# we iterate over the fp16 params
|
|
# on each param, we register a hook to its AccumulateGrad object
|
|
for group_id in range(self.num_param_groups):
|
|
param_group = self._fp16_param_groups[group_id]
|
|
for param in param_group:
|
|
if param.requires_grad:
|
|
# determines the reduction destionation rank
|
|
# this is only valid for stage 2
|
|
# dst_rank = None means using all-reduce
|
|
# else using reduce
|
|
if self._partition_grads:
|
|
reduce_rank = self._param_store.get_param_rank(param)
|
|
else:
|
|
reduce_rank = None
|
|
|
|
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
|
|
|
|
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
    """Flatten ``bucket``, reduce it over the data parallel group and copy back.

    Args:
        bucket: the tensors to reduce together.
        reduce_rank: local destination rank, or ``None``; passed on to
            ``reduce_tensor_dp_group`` (per the comments in this class,
            ``None`` means all-reduce, otherwise reduce to that rank).
    """
    if not self._overlap_communication:
        stream = torch.cuda.current_stream()
    else:
        # wait for outstanding work, then drop the gradients that were kept
        # alive for the previous reduction before reusing the comm stream
        torch.cuda.synchronize()
        self._param_store.clear_grads_of_previous_reduced_params()
        stream = self._comm_stream

    with torch.cuda.stream(stream):
        flat = bucket.flatten()

        # translate the local dp rank into its global rank, if any
        reduce_global_rank = None if reduce_rank is None else self._dp_global_ranks[reduce_rank]

        reduced_flat = reduce_tensor_dp_group(tensor=flat,
                                              dtype=self._communication_dtype,
                                              dst_local_rank=reduce_rank,
                                              dst_global_rank=reduce_global_rank,
                                              group=self._dp_torch_group)

        # only ranks that receive the reduced result write it back
        receives_result = reduce_rank is None or reduce_rank == self._local_rank
        if receives_result:
            bucket.unflatten_and_copy(reduced_flat)
|
|
|
|
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
    """Reduce ``tensor_list`` in chunks using a ``TensorBucket`` of ``bucket_size``.

    Tensors are added one by one; whenever the bucket reports being full or
    oversized it is reduced and emptied. Whatever is left at the end is
    reduced as a final bucket.
    """
    pending = TensorBucket(size=bucket_size)

    for tensor in tensor_list:
        pending.add_to_bucket(tensor, allow_oversize=True)

        if not pending.is_full_or_oversized():
            continue
        self._reduce_tensor_bucket(bucket=pending, reduce_rank=reduce_rank)
        pending.empty()

    # flush the remainder, if there is one
    if pending.is_empty():
        return
    self._reduce_tensor_bucket(bucket=pending, reduce_rank=reduce_rank)
|
|
|
|
def _reduce_grads(self, reduce_rank, grads, bucket_size):
    """Split ``grads`` with ``split_half_float_double`` and reduce each list separately."""
    for same_dtype_grads in split_half_float_double(grads):
        self._reduce_tensor_list_with_one_dtype(tensor_list=same_dtype_grads,
                                                bucket_size=bucket_size,
                                                reduce_rank=reduce_rank)
|
|
|
|
#######################
|
|
# Reduction Functions #
|
|
#######################
|
|
|
|
def _run_reduction(self, reduce_rank=None):
    """Reduce everything queued in the bucket store for ``reduce_rank``.

    After the reduction every queued parameter is flagged as reduced and the
    bucket for ``reduce_rank`` is reset.

    Raises:
        RuntimeError: if a queued parameter is already flagged as reduced.
    """
    # reduce the gradients currently sitting in the bucket
    queued_grads = self._bucket_store.get_grad(reduce_rank=reduce_rank)
    queued_numel = self._bucket_store.num_elements_in_bucket(reduce_rank)
    self._reduce_grads(reduce_rank=reduce_rank, grads=queued_grads, bucket_size=queued_numel)

    # use the communication stream when overlapping communication with computation
    stream = self._comm_stream if self._overlap_communication else torch.cuda.current_stream()

    with torch.cuda.stream(stream):
        for param in self._bucket_store.get_param(reduce_rank=reduce_rank):
            # a parameter must not have been reduced before this point
            if self._param_store.is_param_reduced(param):
                msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
                    'duplicate reduction will lead to arithmetic incorrectness'
                raise RuntimeError(msg)

            # update the flag
            self._param_store.set_param_reduction_state(param, True)

            # with partitioned grads, a gradient owned by another rank is
            # not kept after reduction
            if not self._partition_grads or self._param_store.belongs_to_current_rank(param):
                continue

            if self._overlap_communication:
                # keep this gradient for now: the reduction may not be completed
                # yet since it is running on a different cuda stream
                self._param_store.add_previous_reduced_param(param)
            else:
                param.grad = None

    self._bucket_store.reset_by_rank(reduce_rank)
|
|
|
|
def _add_to_reduction_bucket(self, param, reduce_rank=None):
|
|
param_size = param.numel()
|
|
|
|
# check if the bucket is full
|
|
# if full, will reduce the grads already in the bucket
|
|
# after reduction, the bucket will be empty
|
|
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
|
self._run_reduction(reduce_rank)
|
|
|
|
# the param must not be reduced to ensure correctness
|
|
is_param_reduced = self._param_store.is_param_reduced(param)
|
|
if is_param_reduced:
|
|
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
|
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
|
raise RuntimeError(msg)
|
|
|
|
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
|
self._bucket_store.add_param(param, reduce_rank)
|
|
|
|
################################
|
|
# torch.optim.Optimizer methods
|
|
################################
|
|
|
|
def backward(self, loss, retain_graph=False, sync_grad=True):
    """Scale ``loss``, run backward and finish the gradient reduction.

    Args:
        loss: the loss to back-propagate; it is multiplied by ``self.loss_scale``.
        retain_graph: forwarded to ``Tensor.backward``.
        sync_grad: when true, ``_sync_grad`` is called after the reduction.
    """
    scaled_loss = self.loss_scale * loss
    scaled_loss.backward(retain_graph=retain_graph)

    # finish gradient reduction
    if self._partition_grads:
        # TODO: support async comm in reduce
        self._reduce_grad_stage2()
    else:
        self._reduce_grad_stage1()

    # clear reduced grads once the overlapped communication has finished
    if self._overlap_communication:
        torch.cuda.synchronize()
        self._param_store.clear_grads_of_previous_reduced_params()

    # gradient synchronization
    if sync_grad:
        self._sync_grad()
|
|
|
|
def zero_grad(self, set_to_none=True):
    """
    Set parameter gradients to zero. If set_to_none = True, gradient
    will be set to None to save memory.

    When ``set_to_none`` is False, each existing gradient is detached from
    the autograd graph in place and then filled with zeros; parameters
    without a gradient are left untouched.

    :param set_to_none: Whether set the gradient to None. Default value is True.
    :type set_to_none: bool
    """
    for param_group in self._fp16_param_groups.values():
        for param in param_group:
            if set_to_none:
                param.grad = None
            elif param.grad is not None:
                # `Tensor.detach()` returns a new tensor and leaves the
                # gradient attached; the in-place variants are required here
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
|
|
|
|
####################
|
|
# Update Parameter #
|
|
####################
|
|
|
|
def step(self, closure=None):
    """Run one optimizer step on the parameters owned by this rank.

    The step is skipped when an overflow is detected: the averaged gradients
    are reset, the parameter gradients are cleared and the method returns.
    Otherwise the averaged gradients of each group are flattened, cast to the
    dtype of the flat fp32 parameter, unscaled (and clipped when enabled) and
    attached to the flat fp32 parameter. The wrapped optimizer is then
    stepped, the result is copied back to this rank's flat fp16 parameter and
    every rank's flat fp16 parameter is broadcast over the data parallel group.

    Args:
        closure: must be ``None``; closures are not supported.

    Raises:
        AssertionError: if ``closure`` is given, or if a flat gradient does
            not have the shape of its flat fp32 parameter.
    """
    assert closure is None, 'closure is not supported by step()'

    # check for overflow and let the scaler update the loss scale
    found_inf = self._check_overflow()
    self.grad_scaler.update(found_inf)

    # skip this step if overflow occurs
    if found_inf:
        self._grad_store.reset_all_average_gradients()
        self.zero_grad()
        return

    # flat gradients (in the fp32 param dtype), one per group, plus their norms
    single_grad_partition_groups = []
    norm_groups = []

    for group_id in range(self.num_param_groups):
        avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)

        # compute norm
        norm_group = compute_norm(gradients=avg_grads,
                                  params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                         rank=self._local_rank),
                                  dp_group=self._dp_torch_group,
                                  mp_group=self._mp_torch_group)
        norm_groups.append(norm_group)

        # create flat gradient for the flat fp32 params
        flat_fp16_avg_grads = flatten(avg_grads)

        master_param = self._fp32_flat_param_groups_of_current_rank[group_id]
        flat_fp32_avg_grads = flat_fp16_avg_grads.to(master_param.dtype)

        param_shape = master_param.shape
        assert param_shape == flat_fp32_avg_grads.shape, \
            f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'

        single_grad_partition_groups.append(flat_fp32_avg_grads)
        self._grad_store.reset_average_gradients_by_group(group_id)

    # unscale and clip grads
    global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
    self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)

    # Attach the gradients only AFTER unscaling. `Tensor.to(device)` returns a
    # copy when the device differs (cpu offload), so attaching first would
    # leave the optimizer with a still-scaled, unclipped copy. When the device
    # already matches, `.to` returns the same tensor.
    for group_id, flat_grad in enumerate(single_grad_partition_groups):
        master_param = self._fp32_flat_param_groups_of_current_rank[group_id]
        master_param.grad = flat_grad.to(master_param.device)

    # update the parameters
    self.optim.step()
    # release the fp32 grad
    release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())

    # update fp16 partition updated by the current rank
    for group_id in range(self.num_param_groups):
        fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
        fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
        fp16_param.data.copy_(fp32_param)

    # broadcast the updated model weights: each rank is the source for the
    # flat fp16 tensor it owns
    handles = []
    for group_id in range(self.num_param_groups):
        for index in range(self._world_size):
            rank = self._dp_global_ranks[index]
            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
            handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
            handles.append(handle)

    for handle in handles:
        handle.wait()
|
|
|
|
##################
|
|
# FP16 Utilities #
|
|
##################
|
|
|
|
def _check_overflow(self):
    """Report whether any averaged gradient contains inf/nan on any rank.

    The local result is combined with a MAX all-reduce over the data parallel
    group and, when one is set, over the model parallel group.

    Returns:
        bool: True if an overflow was found anywhere, False otherwise.
    """
    # clear previous overflow record
    self._found_overflow.fill_(0.0)

    # check for overflow; scanning a group stops at its first bad gradient
    for group_id in range(len(self._fp16_param_groups)):
        group_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
        if any(grad is not None and has_inf_or_nan(grad) for grad in group_grads):
            self._found_overflow.fill_(1.0)

    # all-reduce across dp group
    dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

    # all-reduce over model parallel group
    if self._mp_torch_group:
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

    return self._found_overflow.item() > 0
|
|
|
|
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
|
|
# compute combined scale factor for this group
|
|
combined_scale = self.loss_scale
|
|
|
|
if self._clip_grad_norm > 0.:
|
|
# norm is in fact norm*scale
|
|
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
|
|
if clip > 1:
|
|
combined_scale = clip * self.loss_scale
|
|
|
|
for grad in grad_groups_flat:
|
|
grad.data.mul_(1. / combined_scale)
|
|
|
|
############################
|
|
# Gradient Synchronization #
|
|
############################
|
|
|
|
def _sync_grad(self):
|
|
# update param already reduced flag
|
|
reduction_states = self._param_store.get_param_reduction_states()
|
|
for tensor, _ in reduction_states.items():
|
|
reduction_states[tensor] = False
|
|
|
|
# accumulate gradient
|
|
for group_id in range(self.num_param_groups):
|
|
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
|
|
|
|
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
|
|
|
|
param_idx = 0
|
|
for param in param_group:
|
|
if param.grad is not None:
|
|
if len(avg_gradients_group) == param_idx:
|
|
self._grad_store.append_average_gradient_by_group(group_id, param.grad)
|
|
else:
|
|
self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
|
|
param_idx += 1
|
|
|
|
# the gradients needed are stored in the avg_gradients buffer
|
|
# thus, can clear this
|
|
self.zero_grad()
|
|
|
|
def _reduce_grad_stage1(self):
|
|
# if not overlapping communication (no reduction hook is attached)
|
|
# we need to manually reduce these gradients
|
|
if not self._overlap_communication:
|
|
for group_id in range(len(self._fp16_param_groups)):
|
|
param_group = self._fp16_param_groups[group_id]
|
|
for param in param_group:
|
|
if param.grad is not None:
|
|
self._add_to_reduction_bucket(param)
|
|
|
|
# we need to reduce the gradients
|
|
# left in the communication bucket
|
|
self._run_reduction()
|
|
|
|
def _reduce_grad_stage2(self):
|
|
# when partition_grads is True, reduction hooks
|
|
# are attached in the __init__ function, so we
|
|
# only need to reduce the gradients
|
|
# left in the communication bucket
|
|
for reduce_rank in range(self._world_size):
|
|
self._run_reduction(reduce_rank)
|