Feat/optimizer (#194)

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call

* feat(optimizer.py): overlap norm computation with allreduce

* update variable and function names

* update the compute_norm function (#197)

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* feat(optimizer/hybrid_zero_optim.py): overlap the last-bucket gradient allreduce with norm computation (#196)

* support overlapping gradient allreduce with norm computation

* fix parameter-setting error

* remove the cal_norm timer for testing

* feat(optimizer/hybrid_zero_optim.py): support per-group global norm

* format(lint): fix lint error

* feat(optimizer/store.py): update code based on review comments

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: huangting4201 <1538303371@qq.com>
pull/199/head
Sun Peng 2023-08-15 18:55:10 +08:00 committed by GitHub
parent 4e8bd39d8f
commit ef851d16c6
7 changed files with 168 additions and 92 deletions
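The thread running through these commits: gradients are reduced bucket by bucket, so the norm of the buckets that have already been reduced can be computed while the allreduce of the last bucket is still in flight, and the global norm is now tracked per parameter group. A minimal sketch of the overlap idea, using hypothetical helper names rather than the real HybridZeroOptimizer internals:

import torch
import torch.distributed as dist

# Sketch only: launch the last bucket's allreduce on a dedicated comm stream,
# compute the partial norm of the already-reduced buckets on the default
# stream, then fold the last bucket in once the communication has finished.
# Real code must also order the streams (wait_stream/record_stream) and copy
# the reduced values back into the parameter grads; both are omitted here.
def overlapped_norm(former_grads, last_bucket_grads, comm_stream, group):
    flat_last = torch.cat([g.reshape(-1) for g in last_bucket_grads])
    with torch.cuda.stream(comm_stream):
        handle = dist.all_reduce(flat_last, group=group, async_op=True)

    # runs on the default stream, overlapping the allreduce above
    partial_sq = sum(g.float().norm(2) ** 2 for g in former_grads)

    handle.wait()
    comm_stream.synchronize()
    return float((partial_sq + flat_last.float().norm(2) ** 2) ** 0.5)

In the diff below this pattern shows up as _compute_norm_with_stage plus the new last_bucket/last_stage flags threaded through the parameter store and compute_norm.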

.gitignore
View File

@@ -115,6 +115,7 @@ venv.bak/
*.pkl
*.pkl.json
*.log.json
*.trace.json
docs/modelzoo_statistics.md
mmdet/.mim
work_dirs/

View File

@@ -7,23 +7,25 @@ from torch.nn import init
from torch.nn.parameter import Parameter
def manual_rms_norm(input, normalized_shape, weight, eps):
def manual_rms_norm(my_input, normalized_shape, weight, eps):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)
variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
my_input = my_input * torch.rsqrt(variance + eps)
if weight is None:
return input
return my_input
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
input = input.to(weight.dtype)
my_input = my_input.to(weight.dtype)
return weight * input
return weight * my_input
class RMSNormTorch(torch.nn.Module):
"""A custom PyTorch module for RMS normalization."""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
@@ -34,8 +36,8 @@ class RMSNormTorch(torch.nn.Module):
self.weight = Parameter(torch.empty(*normalized_shape))
self.reset_parameters()
def forward(self, input: torch.Tensor):
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
def forward(self, _input: torch.Tensor):
return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)
def reset_parameters(self):
init.ones_(self.weight)
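For reference, the operation implemented by manual_rms_norm above is, with n the product of normalized_shape, w the optional weight, and the mean taken in float32:

y = w \odot \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^{2} + \epsilon}}

(when weight is None the w factor is dropped). The diff only renames the argument (input -> my_input / _input), presumably to stop shadowing the Python builtin input and satisfy the linter; the math is unchanged.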

View File

@@ -88,15 +88,17 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
assert input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), input)
def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
assert my_input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), my_input)
grad_bias = grad_output.sum(dim=0) if has_d_bias else None
return grad_weight, grad_bias
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
"""A custom PyTorch module extending FusedDenseFunc."""
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
@@ -173,8 +175,8 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
def symbolic(input_):
return _split(input_, parallel_mode=None)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):

View File

@@ -178,6 +178,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if len(params) != 0:
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
@@ -317,7 +318,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_stored_in_bucket(reduce_rank)
self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
@@ -335,7 +336,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
# reduce grads
self._reduce_grads_by_rank(
reduce_rank=reduce_rank,
@@ -343,30 +344,27 @@ class HybridZeroOptimizer(BaseOptimizer):
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
)
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = (
f"Parameter of size ({param.size()}) has been reduced, "
+ "duplicate reduction will lead to arithmetic incorrectness"
)
raise RuntimeError(msg)
if is_param_reduced:
msg = (
f"Parameter of size ({param.size()}) has been reduced, "
+ "duplicate reduction will lead to arithmetic incorrectness"
)
raise RuntimeError(msg)
# update the flag
self._param_store.set_param_reduction_state(param, True)
# update the flag
self._param_store.set_param_reduction_state(param, True)
if self._param_store.belongs_to_current_rank(param):
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
else:
self._param_store.add_previous_reduced_param(param)
self._bucket_store.reset_by_rank(reduce_rank)
@@ -385,9 +383,9 @@ class HybridZeroOptimizer(BaseOptimizer):
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
else:
stream = torch.cuda.current_stream()
@@ -421,6 +419,7 @@ class HybridZeroOptimizer(BaseOptimizer):
reduction_states = self._param_store.get_param_reduction_states()
for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
self._param_store.reset_reduced_data_for_compute_norm()
# accumulate gradient
avg_gradients = self._grad_store._averaged_gradients
@@ -469,6 +468,30 @@ class HybridZeroOptimizer(BaseOptimizer):
# Gradients may not be fully synchronized here.
def _compute_norm_with_stage(
self,
group_id: int = 0,
last_bucket: bool = False,
last_stage: bool = False,
previous_norm=None,
):
# compute norm for gradients that have been reduced
params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
if len(params) == 0:
grads = [self.padding_grad]
params = [self.padding_tensor]
if self._clip_grad_norm > 0:
# this norm is computed before unscaling, so it will be very large
norm = compute_norm(
gradients=grads,
parameters=params,
last_stage=last_stage,
previous_norm=previous_norm,
)
return norm
def step(self, closure=None):
"""Performs a single optimization step.
@@ -480,7 +503,6 @@ class HybridZeroOptimizer(BaseOptimizer):
"""
assert closure is None, "closure is not supported by step()"
timer("sync_grad").start()
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
@@ -490,54 +512,49 @@ class HybridZeroOptimizer(BaseOptimizer):
self._store_and_try_reduce_grads_by_bucket(param)
# we need to reduce the gradients left in the communication bucket
self._reduce_grads_stored_in_bucket()
self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
# compute norm for gradients reduced in the earlier buckets
groups_norms = []
for group_id in range(self.num_param_groups):
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
# grads in the last bucket are reduced
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# compute norm for gradients in the last bucket
total_norms = []
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
)
)
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
return self._step(closure=closure)
return self._step(closure=closure, norms=total_norms)
def _step(self, closure=None):
def _step(self, closure=None, norms=None):
assert closure is None, "closure is not supported by step()"
# check for overflow
found_inf = self._check_overflow()
found_inf = False
# if there are INF values in the grads, compute_norm would also return -1
# thus, we try to avoid calling _check_overflow here
# found_inf = self._check_overflow()
# because inf may be encountered when computing the norm
timer("cal_norm").start()
norm_groups = []
for group_id in range(self.num_param_groups):
# compute norm
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
parameters = self._param_store.get_fp16_params_by_rank_group(
group_id=group_id, rank=self._zero_local_rank
)
else:
# in order to prevent collective communication from hanging,
# we need to involve ranks that are not assigned parameters in compute_norm(),
# so we give them an fp16 vector of zero values.
gradients = [self.padding_grad]
parameters = [self.padding_tensor]
if self._clip_grad_norm > 0:
# this norm is computed before unscaling, so it will be very large
norm_group = compute_norm(
gradients=gradients,
parameters=parameters,
)
if norm_group == -1:
timer("cal_norm").stop()
found_inf = True
break
norm_groups.append(norm_group)
if -1 in norms:
found_inf = True
loss_scale = float(self.loss_scale.item()) # backup
if not gpc.config.model.dtype is torch.float32:
if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
@@ -550,7 +567,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
global_norm = 0
for group_id in range(self.num_param_groups):
# compute norm
# The following operations are performed only on the rank to which parameters are assigned.
@@ -559,13 +575,14 @@ class HybridZeroOptimizer(BaseOptimizer):
# create flat gradient for the flat fp32 params
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
flat_fp16_avg_grads = flatten(gradients)
with torch.no_grad():
flat_fp16_avg_grads = flatten(gradients)
self._grad_store.reset_average_gradients_by_group(group_id)
del gradients # release cuda memory
gradients = None # release cuda memory
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
del flat_fp16_avg_grads # release cuda memory
flat_fp16_avg_grads = None # release cuda memory
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
assert (
@@ -578,15 +595,16 @@ class HybridZeroOptimizer(BaseOptimizer):
# unscale and clip grads
# get the global norm
global_norm_groups = []
if self._clip_grad_norm > 0:
global_norm = sum(norm_groups) ** 0.5
for norm in norms:
global_norm_groups.append(norm**0.5)
# the following operations are performed only on the rank to which parameters are assigned.
if not gpc.config.model.dtype is torch.float32:
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
timer("cal_norm").stop()
# update the parameters
timer("step").start()
@@ -611,7 +629,7 @@ class HybridZeroOptimizer(BaseOptimizer):
timer("step").stop()
# updating the gradients may not be needed here, because the sync_params function is used
# during initialization, so synchronization is maintained
return True, global_norm / loss_scale
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
def broadcast_params(self, overlap=False):
handles = []
@@ -655,18 +673,20 @@ class HybridZeroOptimizer(BaseOptimizer):
return self._found_overflow.item() > 0
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
# compute combined scale factor for this group
combined_scale = loss_scale
combined_scale_groups = []
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale = clip * loss_scale
for group_id, total_norm in enumerate(total_norm_groups):
combined_scale_groups.append(loss_scale)
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale_groups[group_id] = clip * loss_scale
for grad in grad_groups_flat:
grad.data.mul_(1.0 / combined_scale)
for group_id, grad in enumerate(grad_groups_flat):
grad.data.mul_(1.0 / combined_scale_groups[group_id])
def clip_grad_norm(self, model, max_norm):
# will conduct in the step()
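A hedged worked example of the per-group clipping above, with made-up numbers:

# hypothetical numbers illustrating _unscale_and_clip_grads for one group
loss_scale = 65536.0
clip_grad_norm = 1.0
total_norm = 3.3e6                                           # norm of the still-scaled grads
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm   # ~50.4
combined_scale = clip * loss_scale if clip > 1.0 else loss_scale
# grad.data.mul_(1.0 / combined_scale) then both unscales the gradients and brings
# this group's norm down to roughly clip_grad_norm; a group whose clip is <= 1.0
# is only divided by loss_scale.

The difference from the old code is simply that each parameter group now gets its own combined scale instead of one shared value.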

View File

@@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
self._is_param_reduced = dict()
self._reduced_param = []
self._former_bucket_reduced_param = {}
self._last_bucket_reduced_param = {}
self._former_bucket_reduced_grad = {}
self._last_bucket_reduced_grad = {}
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping between parameter to rank, each parameter should be owned by a rank.
@@ -223,6 +228,35 @@ class ParameterStore(BaseStore):
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
group_id = getattr(param, "group_id")
if last_bucket:
if group_id not in self._last_bucket_reduced_param:
self._last_bucket_reduced_param[group_id] = []
self._last_bucket_reduced_grad[group_id] = []
self._last_bucket_reduced_param[group_id].append(param)
self._last_bucket_reduced_grad[group_id].append(param.grad)
else:
if group_id not in self._former_bucket_reduced_param:
self._former_bucket_reduced_param[group_id] = []
self._former_bucket_reduced_grad[group_id] = []
self._former_bucket_reduced_param[group_id].append(param)
self._former_bucket_reduced_grad[group_id].append(param.grad)
def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
if not last_bucket:
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
else:
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
def reset_reduced_data_for_compute_norm(self):
self._former_bucket_reduced_param = {}
self._last_bucket_reduced_param = {}
self._former_bucket_reduced_grad = {}
self._last_bucket_reduced_grad = {}
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:
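A hedged usage sketch of the new bookkeeping above, assuming an already-constructed ParameterStore instance `store` and a parameter `param` whose gradient has just been reduced (the group_id attribute is attached in HybridZeroOptimizer.__init__, see the hunk at -178 above):

# during reduction of an earlier bucket:
store.add_reduced_param_for_compute_norm(param, last_bucket=False)
# during reduction of the final bucket:
store.add_reduced_param_for_compute_norm(param, last_bucket=True)

# later, step() pulls them back out per parameter group to compute the norms
params, grads = store.get_reduced_param_for_compute_norm(
    group_id=param.group_id, last_bucket=False
)

# and clears the bookkeeping once per iteration, together with the reduction states
store.reset_reduced_data_for_compute_norm()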

View File

@@ -21,6 +21,7 @@ logger = get_logger(__file__)
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
@@ -162,6 +163,7 @@ def sync_param(flat_tensor, tensor_list):
for p, q in zip(tensor_list, updated_params):
p.data = q.data
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
# Convert tensor_list elements to torch.float32
tensor_list = [tensor.float() for tensor in tensor_list]
@@ -175,6 +177,7 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
return l2_norm, per_tensor_norm
def calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
@@ -187,6 +190,7 @@ def calc_l2_norm(grads):
norm, _ = multi_tensor_l2norm_torch(grads, False)
return norm
def calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
@@ -195,7 +199,7 @@ def calc_lp(grads, norm_type):
return norm
def compute_norm(gradients, parameters, norm_type=2):
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
"""Get the norm
Arguments:
gradients (Iterable[Tensor]): The gradient value.
@@ -215,6 +219,13 @@ def compute_norm(gradients, parameters, norm_type=2):
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
if last_stage is False:
return total_norm_cuda
if previous_norm is not None:
total_norm_cuda = max(total_norm_cuda, previous_norm)
# Take max across all model-parallel GPUs.
if gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
@@ -261,6 +272,12 @@ def compute_norm(gradients, parameters, norm_type=2):
total_norm = tensor_parallel_norm
if last_stage is False:
return total_norm
if previous_norm is not None:
total_norm = total_norm + previous_norm
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL):
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
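The two new arguments let compute_norm be driven in two stages, which is how step() uses it; a hedged usage sketch (former_* / last_* are placeholder names):

# stage 1: partial norm over gradients reduced in earlier buckets; with
# last_stage left False, no cross-rank reduction happens yet
partial = compute_norm(gradients=former_grads, parameters=former_params)

# stage 2: add the last bucket's contribution and the previous partial result,
# then perform the model-parallel reduction; -1 signals inf/nan gradients
total = compute_norm(
    gradients=last_grads,
    parameters=last_params,
    last_stage=True,
    previous_norm=partial,
)

When a rank owns no parameters in a group, _compute_norm_with_stage substitutes padding_grad/padding_tensor so the collective still sees every rank.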

View File

@@ -7,6 +7,7 @@ import traceback
from functools import partial
from typing import Iterable
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
@@ -603,12 +604,12 @@ def main(args):
trainer_result = trainer.step()
assert trainer_result is not None
success_update, grad_norm = trainer_result
success_update, grad_norm_groups = trainer_result
if success_update: # update parameters successfully
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1 # record the number of batches where the parameter update was skipped.
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
@@ -628,7 +629,7 @@ def main(args):
trainer=trainer,
start_time=start_time,
loss=loss,
grad_norm=grad_norm,
grad_norm=np.array(grad_norm_groups),
metric=metric,
update_panel=uniscale_logger is not None,
)
@@ -668,7 +669,6 @@ if __name__ == "__main__":
main(args)
except Exception:
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
exc_info=traceback.format_exc(),
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())