ColossalAI/colossalai/zero/low_level/low_level_optim.py

602 lines
26 KiB
Python

# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from functools import partial
from typing import Optional
import torch
import torch.distributed as dist
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from ._utils import (
calculate_global_norm_from_list,
compute_norm,
flatten,
has_inf_or_nan,
reduce_tensor_dp_group,
release_param_grad,
split_by_dtype,
sync_param,
)
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
num_working_param_groups: int,
grad_store: GradientStore,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
def check_local_overflow(self) -> bool:
for group_id in range(self.num_working_param_groups):
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
return True
return False
class LowLevelZeroOptimizer(ColossalaiOptimizer):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
def __init__(
self,
optimizer: Optimizer,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):
# TODO: add support for
# 1. fp16 master weights
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
# 5. support layer drop
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
# stage 2
self._partition_grads = partition_grad
self._cpu_offload = cpu_offload
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
self._world_size = colo_pg.dp_world_size()
self._dp_global_ranks = colo_pg.get_ranks_in_dp()
self._dp_torch_group = colo_pg.dp_process_group()
self._mp_torch_group = None
if colo_pg.tp_world_size() > 1:
self._mp_torch_group = colo_pg.tp_process_group()
elif colo_pg is None:
dp_parallel_mode = ParallelMode.DATA
mp_parallel_mode = ParallelMode.MODEL
self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
self._dp_torch_group = gpc.get_group(dp_parallel_mode)
self._mp_torch_group = None
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_torch_group = gpc.get_group(mp_parallel_mode)
else:
raise NotImplementedError
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_flat_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient clipping
self._clip_grad_norm = clip_grad_norm
if forced_dtype:
for group in self.optim.param_groups:
group_params = group['params']
for param in group_params:
param.data = param.data.to(forced_dtype)
self._dtype = forced_dtype
# check argument conflict
self._sanity_checks()
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group)
self._bucket_store = BucketStore(self._dp_torch_group)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = list()
for param in param_group['params']:
if param.requires_grad:
group_params.append(param)
# add the working params to working_param_groups for bookkeeping
self._working_param_groups[group_id] = group_params
# assign parameters to ranks
# the params in the list are sorted
params_per_rank = self._partition_param_list(group_params)
# store the mapping between param to rank
# each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
self._param_store.add_param_list_by_rank_group(rank, group_id, params)
for param in params:
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
# move_tensor(params, device='cpu')
for param in group_params:
param.data = param.data.cpu()
# flatten the reordered tensors
for rank in range(self._world_size):
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
with torch.no_grad():
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.data.cuda()
self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)
# sync parameters
for rank in range(self._world_size):
flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
# create a copy of fp32 master weights of the parameters for which this rank is responsible
working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
master_flat_current_rank = working_flat_current_rank.float()
device = 'cpu' if self._cpu_offload else get_current_device()
master_flat_current_rank = master_flat_current_rank.to(device)
master_flat_current_rank.requires_grad = True
self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = [master_flat_current_rank]
# set reduction state
for param in self._working_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
@property
def dtype(self):
return self._dtype
@property
def num_param_groups(self):
return len(self._working_param_groups)
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _search_colo_process_group(self):
colo_flag = False
colo_pg = None
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
if isinstance(param, ColoParameter):
colo_flag = True
if colo_pg is None:
colo_pg = param.get_process_group()
else:
assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
elif colo_flag:
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
# partititon the parameters in a greedy fashion
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for param in sorted_params:
# allocate this parameter to the rank with
# the smallest numel for load balancing purpose
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
numel_per_rank[rank_to_go] += param.numel()
if self._verbose:
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank
###########################
# Backward Reduction Hook #
###########################
def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
return grad
def _attach_reduction_hook(self):
# we iterate over the working params
# on each param, we register a hook to its AccumulateGrad object
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
# determines the reduction destionation rank
# this is only valid for stage 2
# dst_rank = None means using all-reduce
# else using reduce
if self._partition_grads:
reduce_rank = self._param_store.get_param_rank(param)
else:
reduce_rank = None
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_grads(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_by_dtype(grads)
for tensor_list in grad_buckets_by_dtype:
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
bucket_size=bucket_size,
reduce_rank=reduce_rank)
#######################
# Reduction Functions #
#######################
def _run_reduction(self, reduce_rank=None):
# reduce grads
self._reduce_grads(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
# update the flag
self._param_store.set_param_reduction_state(param, True)
# if partition grads = True
# we do not keep the gradient after reduction
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
if self._overlap_communication:
# we need to keep this gradient for now as reduction may
# be completed yet since it is using a different cuda stream
self._param_store.add_previous_reduced_param(param)
else:
param.grad = None
self._bucket_store.reset_by_rank(reduce_rank)
def _add_to_reduction_bucket(self, param, reduce_rank=None):
param_size = param.numel()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._run_reduction(reduce_rank)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
################################
# torch.optim.Optimizer methods
################################
def backward(self, loss, retain_graph=False, sync_grad=True):
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
loss.backward(retain_graph=retain_graph)
# finish gradient reduction
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# gradient synchronization
if sync_grad:
self._sync_grad()
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
will be set to None to save memory.
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
if self.mixed_precision_mixin is not None:
self.mixed_precision_mixin.pre_zero_grad()
for _, param_group in self._working_param_groups.items():
for param in param_group:
if set_to_none:
param.grad = None
else:
if param.grad is not None:
param.grad.detach()
param.grad.zero_()
####################
# Update Parameter #
####################
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_average_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
return
# copy the grad of working param to master param
single_grad_partition_groups = []
norm_groups = []
for group_id in range(self.num_param_groups):
# compute norm
norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
params=self._param_store.get_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 master params
working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
flat_working_avg_grads = flatten(working_avg_grads)
dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
flat_master_avg_grads = flat_working_avg_grads.to(dtype)
param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
assert param_shape == flat_master_avg_grads.shape, \
f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'
single_grad_partition_groups.append(flat_master_avg_grads)
device = self._master_flat_param_groups_of_current_rank[group_id].device
self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
self._grad_store.reset_average_gradients_by_group(group_id)
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
# update the parameters
self.optim.step()
# release the master grad
release_param_grad(self._master_flat_param_groups_of_current_rank.values())
# update working partition updated by the current rank
for group_id in range(len(self._working_param_groups)):
working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
master_param = self._master_flat_param_groups_of_current_rank[group_id]
working_param.data.copy_(master_param)
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
#############################
# Mixed Precision Utilities #
#############################
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
div_scale = 1.0
if self.mixed_precision_mixin is not None:
div_scale = self.mixed_precision_mixin.get_grad_div_scale()
if self._clip_grad_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
div_scale = clip * div_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / div_scale)
############################
# Gradient Synchronization #
############################
def _sync_grad(self):
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
# accumulate gradient
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients_group) == param_idx:
self._grad_store.append_average_gradient_by_group(group_id, param.grad)
else:
self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
self.zero_grad()
def _reduce_grad_stage1(self):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
for group_id in range(len(self._working_param_groups)):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._add_to_reduction_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
self._run_reduction()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
# are attached in the __init__ function, so we
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._run_reduction(reduce_rank)