mirror of https://github.com/InternLM/InternLM

Feat/optimizer (#194)

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call
* feat(optimizer.py): overlap compute norm with allreduce
* update var and function name
* update function compute norm (#197)
  Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
* feat(optimizer/hybrid_zero_optim.py): overlap gradients last bucket allreduce and compute norm (#196)
* support gradients allreduce and compute norm overlap
* fix para set error
* remove timer cal_norm for testing
* feat(optimizer/hybrid_zero_optim.py): support group global norm
* format(lint): fix lint error
* feat(optimizer/store.py): update code based on comment

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: huangting4201 <1538303371@qq.com>

Branch: pull/199/head
parent 4e8bd39d8f
commit ef851d16c6
@@ -115,6 +115,7 @@ venv.bak/
 *.pkl
 *.pkl.json
 *.log.json
+*.trace.json
 docs/modelzoo_statistics.md
 mmdet/.mim
 work_dirs/
@@ -7,23 +7,25 @@ from torch.nn import init
 from torch.nn.parameter import Parameter


-def manual_rms_norm(input, normalized_shape, weight, eps):
+def manual_rms_norm(my_input, normalized_shape, weight, eps):
     # layer norm should always be calculated in float32
     dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
-    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
-    input = input * torch.rsqrt(variance + eps)
+    variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    my_input = my_input * torch.rsqrt(variance + eps)

     if weight is None:
-        return input
+        return my_input

     # convert into half-precision if necessary
     if weight.dtype in [torch.float16, torch.bfloat16]:
-        input = input.to(weight.dtype)
+        my_input = my_input.to(weight.dtype)

-    return weight * input
+    return weight * my_input


 class RMSNormTorch(torch.nn.Module):
+    """A custom PyTorch module for RMS normalization."""

     def __init__(self, normalized_shape, eps=1e-5):
         super().__init__()
@@ -34,8 +36,8 @@ class RMSNormTorch(torch.nn.Module):
         self.weight = Parameter(torch.empty(*normalized_shape))
         self.reset_parameters()

-    def forward(self, input: torch.Tensor):
-        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+    def forward(self, _input: torch.Tensor):
+        return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)

     def reset_parameters(self):
         init.ones_(self.weight)
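The two hunks above rename the argument of manual_rms_norm and RMSNormTorch.forward so the Python built-in `input` is no longer shadowed; the arithmetic itself is unchanged. As a point of reference, the sketch below reimplements the same RMS-norm computation in standalone form (the name `rms_norm_reference` and the tensor shapes are illustrative, not part of the repository):

import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-5):
    # mean of squares over the normalized (trailing) dims, computed in float32
    dims = tuple(range(-len(normalized_shape), 0))
    variance = x.to(torch.float32).pow(2).mean(dims, keepdim=True)
    y = x * torch.rsqrt(variance + eps)
    if weight is None:
        return y
    # cast back to half precision if the affine weight is fp16/bf16
    if weight.dtype in (torch.float16, torch.bfloat16):
        y = y.to(weight.dtype)
    return weight * y

x = torch.randn(2, 4, 8)
out = rms_norm_reference(x, normalized_shape=(8,), weight=torch.ones(8))
print(out.shape)  # torch.Size([2, 4, 8])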
@@ -88,15 +88,17 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


-def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
-    assert input.dtype == grad_output.dtype
-    grad_weight = torch.matmul(grad_output.t(), input)
+def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
+    assert my_input.dtype == grad_output.dtype
+    grad_weight = torch.matmul(grad_output.t(), my_input)
     grad_bias = grad_output.sum(dim=0) if has_d_bias else None
     return grad_weight, grad_bias


 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
+    """A custom PyTorch module extending FusedDenseFunc."""

     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
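linear_bias_wgrad_torch is the plain-PyTorch fallback for the fused weight-gradient kernel: dW is grad_output transposed times the input, and db is the batch sum of grad_output. A small self-check against autograd (a standalone sketch; the tensor sizes are arbitrary):

import torch

def linear_bias_wgrad_reference(x, grad_output, has_d_bias):
    # same math as linear_bias_wgrad_torch: dW = dY^T @ X, db = sum over the batch dim
    assert x.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), x)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias

x = torch.randn(16, 32)                     # (batch, in_features)
w = torch.randn(8, 32, requires_grad=True)  # (out_features, in_features)
b = torch.zeros(8, requires_grad=True)
y = torch.nn.functional.linear(x, w, b)
grad_out = torch.randn_like(y)
y.backward(grad_out)

dw, db = linear_bias_wgrad_reference(x, grad_out, has_d_bias=True)
print(torch.allclose(dw, w.grad, atol=1e-5), torch.allclose(db, b.grad, atol=1e-5))  # True True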
@@ -173,8 +175,8 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def symbolic(graph, input_):
-        return _split(input_)
+    def symbolic(input_):
+        return _split(input_, parallel_mode=None)

     @staticmethod
     def forward(ctx, input_, parallel_mode, dim):
@@ -178,6 +178,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             if len(params) != 0:
                 self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
                 for param in params:
+                    setattr(param, "group_id", group_id)
                     self._param_store.set_param_to_rank(param, rank)

         # move to cpu to make room to create the flat tensor
@@ -317,7 +318,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank)
+            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)

         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
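Passing last_bucket=False here marks these as the ordinary, threshold-triggered flushes: a bucket is reduced as soon as adding the next parameter would push it past _reduce_bucket_size, while the final flush issued from step() is tagged last_bucket=True. The toy sketch below shows only the flush-on-threshold pattern (ToyBucket and flush_fn are illustrative names, not the repository's BucketStore API):

import torch

class ToyBucket:
    """Accumulate gradients and flush when an element budget would be exceeded."""

    def __init__(self, max_elements):
        self.max_elements = max_elements
        self.grads, self.count = [], 0

    def add(self, grad, flush_fn):
        if self.count + grad.numel() > self.max_elements:
            flush_fn(self.grads)            # e.g. launch an (async) all-reduce of the whole bucket
            self.grads, self.count = [], 0
        self.grads.append(grad)
        self.count += grad.numel()

bucket = ToyBucket(max_elements=1000)
for _ in range(8):
    bucket.add(torch.randn(300), flush_fn=lambda grads: print(f"flush {len(grads)} grads"))
# grads still left in the bucket are flushed once more at the end of the step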
@@ -335,7 +336,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._bucket_store.add_grad(param.grad, reduce_rank)
         self._bucket_store.add_param(param, reduce_rank)

-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
+    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
@@ -343,14 +344,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
         )

-        # use communication stream if overlapping
-        # communication with computation
-        if self._overlap_communication:
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
         params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)

         for param in params_in_bucket:
@@ -368,6 +361,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             # update the flag
             self._param_store.set_param_reduction_state(param, True)

+            if self._param_store.belongs_to_current_rank(param):
+                self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
+            else:
+                self._param_store.add_previous_reduced_param(param)
+
         self._bucket_store.reset_by_rank(reduce_rank)

     def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
@ -385,9 +383,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
|
|
||||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||||
if self._overlap_communication:
|
if self._overlap_communication:
|
||||||
torch.cuda.synchronize()
|
|
||||||
self._param_store.clear_grads_of_previous_reduced_params()
|
|
||||||
stream = self._comm_stream
|
stream = self._comm_stream
|
||||||
|
stream.synchronize()
|
||||||
|
self._param_store.clear_grads_of_previous_reduced_params()
|
||||||
else:
|
else:
|
||||||
stream = torch.cuda.current_stream()
|
stream = torch.cuda.current_stream()
|
||||||
|
|
||||||
|
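The change above waits only on the dedicated communication stream instead of calling torch.cuda.synchronize(), so kernels already queued on the default (compute) stream are not stalled while earlier reductions finish. A minimal sketch of that pattern, guarded so it only runs on a CUDA machine (the stream name and the stand-in work are illustrative):

import torch

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    x = torch.randn(2048, 2048, device="cuda")

    comm_stream.wait_stream(torch.cuda.current_stream())  # make sure x is ready before the side stream uses it
    with torch.cuda.stream(comm_stream):
        reduced = x * 0.5          # stands in for an all-reduce launched on the side stream

    y = x @ x                      # compute work keeps running on the default stream meanwhile

    comm_stream.synchronize()      # wait only for the side stream, not the whole device
    print(reduced.sum().item(), y.shape)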
@@ -421,6 +419,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         reduction_states = self._param_store.get_param_reduction_states()
         for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False
+        self._param_store.reset_reduced_data_for_compute_norm()

         # accumulate gradient
         avg_gradients = self._grad_store._averaged_gradients
@@ -469,6 +468,30 @@ class HybridZeroOptimizer(BaseOptimizer):

         # Gradients may not be fully synchronized here.

+    def _compute_norm_with_stage(
+        self,
+        group_id: int = 0,
+        last_bucket: bool = False,
+        last_stage: bool = False,
+        previous_norm=None,
+    ):
+        # compute norm for gradients that have been reduced
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+        if len(params) == 0:
+            grads = [self.padding_grad]
+            params = [self.padding_tensor]
+
+        if self._clip_grad_norm > 0:
+            # this norm is before scaling, it will be very large
+            norm = compute_norm(
+                gradients=grads,
+                parameters=params,
+                last_stage=last_stage,
+                previous_norm=previous_norm,
+            )
+
+        return norm
+
     def step(self, closure=None):
         """Performs a single optimization step.
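_compute_norm_with_stage is designed to be called twice per parameter group: once over the gradients reduced in the earlier buckets (returning a partial, not-yet-reduced norm) and once with last_bucket=True, last_stage=True and that partial result passed as previous_norm. The schematic below mimics the calling pattern with the store and compute_norm stubbed out (all names and numbers are illustrative, not the repository's API):

def compute_norm_stub(grads, last_stage=False, previous_norm=None):
    # pretend each "gradient" is a scalar; the real compute_norm works on tensors
    partial = sum(float(g) ** 2 for g in grads)
    if not last_stage:
        return partial                          # stage 1: local partial result, no collective yet
    total = partial + (previous_norm or 0.0)    # stage 2: fold in the earlier buckets
    return total                                # the real code also all-reduces across ranks here

former_bucket_grads = {0: [1.0, 2.0], 1: [0.5]}   # per param group, reduced before the last bucket
last_bucket_grads = {0: [3.0], 1: [1.5]}

groups_norms = {g: compute_norm_stub(former_bucket_grads[g]) for g in (0, 1)}
total_norms = {
    g: compute_norm_stub(last_bucket_grads[g], last_stage=True, previous_norm=groups_norms[g])
    for g in (0, 1)
}
print({g: round(n ** 0.5, 3) for g, n in total_norms.items()})   # per-group gradient norms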
@@ -480,7 +503,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"

-        timer("sync_grad").start()
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_communication:
@@ -490,54 +512,49 @@ class HybridZeroOptimizer(BaseOptimizer):
                     self._store_and_try_reduce_grads_by_bucket(param)

             # we need to reduce the gradients left in the communication bucket
-            self._reduce_grads_stored_in_bucket()
+            self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)

+        # compute norm for gradients reduced before the last bucket
+        groups_norms = []
+        for group_id in range(self.num_param_groups):
+            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
         if self._overlap_communication:
-            torch.cuda.synchronize()
+            # grads in the last bucket are reduced
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()

+        # compute norm for gradients in the last bucket
+        total_norms = []
+        for group_id in range(self.num_param_groups):
+            total_norms.append(
+                self._compute_norm_with_stage(
+                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                )
+            )
+
+        timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

-        return self._step(closure=closure)
+        return self._step(closure=closure, norms=total_norms)

-    def _step(self, closure=None):
+    def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
-        found_inf = self._check_overflow()
+        found_inf = False
+        # if there are INF values in grads, compute_norm would also return -1
+        # thus, we try to avoid calling _check_overflow here
+        # found_inf = self._check_overflow()
         # Because you may encounter inf when computing norm
-        timer("cal_norm").start()
-        norm_groups = []
-        for group_id in range(self.num_param_groups):
-            # compute norm
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-                parameters = self._param_store.get_fp16_params_by_rank_group(
-                    group_id=group_id, rank=self._zero_local_rank
-                )
-            else:
-                # in order to prevent collective communication from hanging,
-                # we need to involve ranks that are not assigned parameters in compute_norm(),
-                # so we give them a fp16 vector of 0 values.
-                gradients = [self.padding_grad]
-                parameters = [self.padding_tensor]

-            if self._clip_grad_norm > 0:
-                # this norm is before scaling, it will be very large
-                norm_group = compute_norm(
-                    gradients=gradients,
-                    parameters=parameters,
-                )
-                if norm_group == -1:
-                    timer("cal_norm").stop()
-                    found_inf = True
-                    break
-                norm_groups.append(norm_group)
+        if -1 in norms:
+            found_inf = True

         loss_scale = float(self.loss_scale.item())  # backup
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
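Dropping the explicit _check_overflow() call removes one extra collective: compute_norm already returns -1 when it meets non-finite gradients, so scanning the per-group norms for -1 is enough to detect overflow. A small sketch of that convention (norm_or_flag is a hypothetical helper, not the repository's function):

import torch

def norm_or_flag(grads):
    # return the summed squared L2 norm, or -1.0 if any gradient is non-finite
    total = 0.0
    for g in grads:
        if not torch.isfinite(g).all():
            return -1.0
        total += float(g.float().norm(2)) ** 2
    return total

norms = [
    norm_or_flag([torch.randn(4)]),
    norm_or_flag([torch.tensor([float("inf"), 1.0])]),
]
found_inf = -1 in norms        # same membership test as in _step()
print(norms, found_inf)        # [..., -1.0] True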
@@ -550,7 +567,6 @@ class HybridZeroOptimizer(BaseOptimizer):

         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
-        global_norm = 0
         for group_id in range(self.num_param_groups):
             # compute norm
             # The following operations are performed only on the rank to which parameters are assigned.
@@ -559,13 +575,14 @@ class HybridZeroOptimizer(BaseOptimizer):

                 # create flat gradient for the flat fp32 params
                 gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
+                with torch.no_grad():
                     flat_fp16_avg_grads = flatten(gradients)
                 self._grad_store.reset_average_gradients_by_group(group_id)
-                del gradients  # release cuda memory
+                gradients = None  # release cuda memory

                 dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
                 flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-                del flat_fp16_avg_grads  # release cuda memory
+                flat_fp16_avg_grads = None  # release cuda memory

                 param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
                 assert (
@@ -578,15 +595,16 @@ class HybridZeroOptimizer(BaseOptimizer):

         # unscale and clip grads
         # get the global norm
+        global_norm_groups = []
         if self._clip_grad_norm > 0:
-            global_norm = sum(norm_groups) ** 0.5
+            for norm in norms:
+                global_norm_groups.append(norm**0.5)

         # the following operations are performed only on the rank to which parameters are assigned.
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)

-        timer("cal_norm").stop()
         # update the parameters
         timer("step").start()
@@ -611,7 +629,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         timer("step").stop()
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, global_norm / loss_scale
+        return True, [global_norm / loss_scale for global_norm in global_norm_groups]

     def broadcast_params(self, overlap=False):
         handles = []
@@ -655,18 +673,20 @@ class HybridZeroOptimizer(BaseOptimizer):

         return self._found_overflow.item() > 0

-    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
+    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
         # compute combined scale factor for this group
-        combined_scale = loss_scale
+        combined_scale_groups = []

         if self._clip_grad_norm > 0.0:
             # norm is in fact norm*scale
-            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
-            if clip > 1.0:
-                combined_scale = clip * loss_scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale

-        for grad in grad_groups_flat:
-            grad.data.mul_(1.0 / combined_scale)
+        for group_id, grad in enumerate(grad_groups_flat):
+            grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
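With group-wise norms, each parameter group now gets its own combined scale: every gradient is divided by loss_scale, and additionally by clip whenever its unscaled norm exceeds the clipping limit. A worked example with made-up numbers (remember that the norms passed in were computed on still-scaled gradients):

loss_scale = 1024.0
clip_grad_norm = 1.0
total_norm_groups = [512.0, 4096.0]     # per-group norms of the *scaled* grads, i.e. roughly norm * loss_scale

combined_scale_groups = []
for group_id, total_norm in enumerate(total_norm_groups):
    combined_scale_groups.append(loss_scale)
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm
    if clip > 1.0:                                      # only clip the groups whose unscaled norm is too large
        combined_scale_groups[group_id] = clip * loss_scale

print(combined_scale_groups)   # ~[1024.0, 4096.0]: group 0 is only unscaled, group 1 is unscaled and clipped by 4x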
@@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
         self._is_param_reduced = dict()
         self._reduced_param = []

+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
         """
         Set the mapping between parameter to rank, each parameter should be owned by a rank.
@@ -223,6 +228,35 @@ class ParameterStore(BaseStore):
     def add_previous_reduced_param(self, tensor):
         self._reduced_param.append(tensor)

+    def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
+        group_id = getattr(param, "group_id")
+        if last_bucket:
+            if group_id not in self._last_bucket_reduced_param:
+                self._last_bucket_reduced_param[group_id] = []
+                self._last_bucket_reduced_grad[group_id] = []
+
+            self._last_bucket_reduced_param[group_id].append(param)
+            self._last_bucket_reduced_grad[group_id].append(param.grad)
+        else:
+            if group_id not in self._former_bucket_reduced_param:
+                self._former_bucket_reduced_param[group_id] = []
+                self._former_bucket_reduced_grad[group_id] = []
+
+            self._former_bucket_reduced_param[group_id].append(param)
+            self._former_bucket_reduced_grad[group_id].append(param.grad)
+
+    def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
+        if not last_bucket:
+            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+        else:
+            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+
+    def reset_reduced_data_for_compute_norm(self):
+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def clear_grads_of_previous_reduced_params(self):
         if len(self._reduced_param) > 0:
             for param in self._reduced_param:
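These store methods rely on the group_id attribute attached to each parameter in the optimizer's setup (the setattr hunk further up): every reduced parameter and its gradient is filed under (group_id, former/last bucket) so norms can later be fetched per group and per stage. A condensed, standalone sketch of the same bookkeeping (ToyReducedStore is an illustrative name, not the repository's ParameterStore):

from collections import defaultdict

import torch

class ToyReducedStore:
    def __init__(self):
        # keyed by (group_id, last_bucket): parameters / grads whose reduction has finished
        self._params = defaultdict(list)
        self._grads = defaultdict(list)

    def add(self, param, last_bucket=False):
        group_id = getattr(param, "group_id")   # set earlier via setattr(param, "group_id", ...)
        self._params[(group_id, last_bucket)].append(param)
        self._grads[(group_id, last_bucket)].append(param.grad)

    def get(self, group_id=0, last_bucket=False):
        return self._params[(group_id, last_bucket)], self._grads[(group_id, last_bucket)]

    def reset(self):
        self._params.clear()
        self._grads.clear()

p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.randn(4)
p.group_id = 0                 # what setattr(param, "group_id", group_id) does in the optimizer

store = ToyReducedStore()
store.add(p, last_bucket=False)
params, grads = store.get(group_id=0, last_bucket=False)
print(len(params), len(grads))  # 1 1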
@@ -21,6 +21,7 @@ logger = get_logger(__file__)
 try:
     import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier

     APEX_AVAILABLE = True
 except (ModuleNotFoundError, ImportError):
     logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
@@ -162,6 +163,7 @@ def sync_param(flat_tensor, tensor_list):
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data


 def multi_tensor_l2norm_torch(tensor_list, per_tensor):
     # Convert tensor_list elements to torch.float32
     tensor_list = [tensor.float() for tensor in tensor_list]
@@ -175,6 +177,7 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):

     return l2_norm, per_tensor_norm


 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
@@ -187,6 +190,7 @@ def calc_l2_norm(grads):
         norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm


 def calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:
@@ -195,7 +199,7 @@ def calc_lp(grads, norm_type):
     return norm


-def compute_norm(gradients, parameters, norm_type=2):
+def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
     """Get the norm
     Arguments:
         gradients (Iterable[Tensor]): The gradient value.
@@ -215,6 +219,13 @@ def compute_norm(gradients, parameters, norm_type=2):
     if norm_type == inf:
         total_norm = max(g.data.abs().max() for g in gradients)
         total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
+
+        if last_stage is False:
+            return total_norm_cuda
+
+        if previous_norm is not None:
+            total_norm_cuda = max(total_norm_cuda, previous_norm)
+
         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
@@ -261,6 +272,12 @@ def compute_norm(gradients, parameters, norm_type=2):

         total_norm = tensor_parallel_norm

+        if last_stage is False:
+            return total_norm
+
+        if previous_norm is not None:
+            total_norm = total_norm + previous_norm
+
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
             dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
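Inside compute_norm, the early return on `last_stage is False` means intermediate calls never touch the model-parallel collectives; only the final call folds previous_norm in (a max for the infinity norm, a sum otherwise) and then all-reduces. A single-process sketch of that control flow, with the collective replaced by a comment (staged_sq_norm is an illustrative name):

import torch

def staged_sq_norm(local_sq_norm, last_stage=False, previous_norm=None):
    total = torch.tensor([local_sq_norm], dtype=torch.float32)
    if last_stage is False:
        return total                      # partial result, no collective yet
    if previous_norm is not None:
        total = total + previous_norm     # fold in the earlier buckets' partial norm
    # real code: dist.all_reduce(total, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
    return total

first = staged_sq_norm(2.0)                                        # former buckets
final = staged_sq_norm(1.0, last_stage=True, previous_norm=first)  # last bucket, combined
print(float(final) ** 0.5)                                         # combined gradient norm ~1.73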
train.py (10 changed lines)
@@ -7,6 +7,7 @@ import traceback
 from functools import partial
 from typing import Iterable

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch import nn
@@ -603,12 +604,12 @@ def main(args):
         trainer_result = trainer.step()
         assert trainer_result is not None

-        success_update, grad_norm = trainer_result
+        success_update, grad_norm_groups = trainer_result
         if success_update:  # update parameters successfully
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
@@ -628,7 +629,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            grad_norm=grad_norm,
+            grad_norm=np.array(grad_norm_groups),
             metric=metric,
             update_panel=uniscale_logger is not None,
         )
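Because the optimizer now returns one norm per parameter group, the training loop treats grad_norm as a list: the skip check becomes a membership test and the value is wrapped in a NumPy array before being handed to the logger. A minimal sketch of consuming such a result (the numbers are made up):

import numpy as np

success_update, grad_norm_groups = False, [1.25, -99.0]   # shape of what trainer.step() now returns

if not success_update and -99.0 in grad_norm_groups:
    print("skip parameter update")                         # mirrors the warning / alert branch

print(np.array(grad_norm_groups))                          # per-group norms logged as a single array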
@@ -668,7 +669,6 @@ if __name__ == "__main__":
         main(args)
     except Exception:
         logger.error(
-            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
-            exc_info=traceback.format_exc(),
+            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
         )
         mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())