mirror of https://github.com/InternLM/InternLM
Feat/optimizer (#194)
* feat(optimizer.py): reduce memory footprint and avoid the _check_overflow call
* feat(optimizer.py): overlap compute norm with allreduce
* update var and function name
* update function compute norm (#197)
* feat(optimizer/hybrid_zero_optim.py): overlap last-bucket gradient allreduce with norm computation (#196)
* support gradients allreduce and compute norm overlap
* fix para set error
* remove timer cal_norm for testing
* feat(optimizer/hybrid_zero_optim.py): support group global norm
* format(lint): fix lint error
* feat(optimizer/store.py): update code based on comment

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: huangting4201 <1538303371@qq.com>
parent 4e8bd39d8f
commit ef851d16c6
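The heart of this change, as the commit message above describes, is to overlap gradient-norm computation with the last bucket's allreduce and to drop the separate `_check_overflow` pass: norms for buckets that were reduced earlier are computed while the final bucket is still in flight, the last bucket's contribution is folded in afterwards, and a norm of -1 doubles as the overflow signal. A minimal single-process sketch of that accumulation pattern (the helper below is illustrative, not the repository's API):

    import math
    import torch

    def partial_sq_norm(grads, previous=0.0):
        # accumulate the squared L2 norm of a list of gradients onto a running total;
        # -1.0 is used as an overflow/NaN sentinel, so no separate overflow check is needed
        if previous == -1.0:
            return -1.0
        total = previous
        for g in grads:
            if not torch.isfinite(g).all():
                return -1.0
            total += g.float().pow(2).sum().item()
        return total

    # stage 1: buckets whose allreduce already finished (overlaps the last allreduce)
    earlier = partial_sq_norm([torch.randn(4), torch.randn(8)])
    # stage 2: the last bucket, folded into the running total once it has been reduced
    total_sq = partial_sq_norm([torch.randn(16)], previous=earlier)

    grad_norm = -1.0 if total_sq == -1.0 else math.sqrt(total_sq)
    # -1.0 means "skip this step and let the loss scaler shrink the scale"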
@@ -115,6 +115,7 @@ venv.bak/
 *.pkl
 *.pkl.json
 *.log.json
+*.trace.json
 docs/modelzoo_statistics.md
 mmdet/.mim
 work_dirs/
@@ -7,23 +7,25 @@ from torch.nn import init
 from torch.nn.parameter import Parameter


-def manual_rms_norm(input, normalized_shape, weight, eps):
+def manual_rms_norm(my_input, normalized_shape, weight, eps):
     # layer norm should always be calculated in float32
     dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
-    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
-    input = input * torch.rsqrt(variance + eps)
+    variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    my_input = my_input * torch.rsqrt(variance + eps)

     if weight is None:
-        return input
+        return my_input

     # convert into half-precision if necessary
     if weight.dtype in [torch.float16, torch.bfloat16]:
-        input = input.to(weight.dtype)
+        my_input = my_input.to(weight.dtype)

-    return weight * input
+    return weight * my_input


 class RMSNormTorch(torch.nn.Module):
     """A custom PyTorch module for RMS normalization."""

     def __init__(self, normalized_shape, eps=1e-5):
         super().__init__()

@@ -34,8 +36,8 @@ class RMSNormTorch(torch.nn.Module):
         self.weight = Parameter(torch.empty(*normalized_shape))
         self.reset_parameters()

-    def forward(self, input: torch.Tensor):
-        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+    def forward(self, _input: torch.Tensor):
+        return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)

     def reset_parameters(self):
         init.ones_(self.weight)
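The only functional content of the hunk above is the argument rename (`input` to `my_input` / `_input`), which avoids shadowing Python's built-in `input`; the normalization itself is unchanged: statistics are taken in float32 and the result is cast back to the weight dtype. A standalone check of that behaviour (the function body is repeated from the diff so the snippet runs on its own):

    import torch

    def manual_rms_norm(my_input, normalized_shape, weight, eps):
        # same logic as the function in the diff above
        dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
        variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
        my_input = my_input * torch.rsqrt(variance + eps)
        if weight is None:
            return my_input
        if weight.dtype in [torch.float16, torch.bfloat16]:
            my_input = my_input.to(weight.dtype)
        return weight * my_input

    x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
    w = torch.ones(64, dtype=torch.bfloat16)
    y = manual_rms_norm(x, (64,), w, eps=1e-5)
    assert y.shape == x.shape and y.dtype == torch.bfloat16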
@@ -88,15 +88,17 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


-def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
-    assert input.dtype == grad_output.dtype
-    grad_weight = torch.matmul(grad_output.t(), input)
+def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
+    assert my_input.dtype == grad_output.dtype
+    grad_weight = torch.matmul(grad_output.t(), my_input)
     grad_bias = grad_output.sum(dim=0) if has_d_bias else None
     return grad_weight, grad_bias


 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
     """A custom PyTorch module extending FusedDenseFunc."""

     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):

@@ -173,8 +175,8 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def symbolic(graph, input_):
-        return _split(input_)
+    def symbolic(input_):
+        return _split(input_, parallel_mode=None)

     @staticmethod
     def forward(ctx, input_, parallel_mode, dim):
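`linear_bias_wgrad_torch` computes the weight and bias gradients of a linear layer from the saved input and the incoming output gradient (dW = dY^T X, db = sum of dY over the batch); the rename again only avoids shadowing the `input` built-in. A quick standalone check of those formulas against PyTorch autograd:

    import torch

    batch, in_features, out_features = 4, 8, 3
    x = torch.randn(batch, in_features)
    w = torch.randn(out_features, in_features, requires_grad=True)
    b = torch.randn(out_features, requires_grad=True)

    out = torch.nn.functional.linear(x, w, b)
    grad_output = torch.randn_like(out)
    out.backward(grad_output)

    # same formulas as linear_bias_wgrad_torch in the diff above
    grad_weight = torch.matmul(grad_output.t(), x)
    grad_bias = grad_output.sum(dim=0)

    assert torch.allclose(grad_weight, w.grad, atol=1e-5)
    assert torch.allclose(grad_bias, b.grad, atol=1e-5)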
@@ -178,6 +178,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if len(params) != 0:
                     self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
                     for param in params:
+                        setattr(param, "group_id", group_id)
                         self._param_store.set_param_to_rank(param, rank)

             # move to cpu to make room to create the flat tensor

@@ -317,7 +318,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank)
+            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)

         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)

@@ -335,7 +336,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._bucket_store.add_grad(param.grad, reduce_rank)
         self._bucket_store.add_param(param, reduce_rank)

-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
+    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
@@ -343,30 +344,27 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
         )

-        # use communication stream if overlapping
-        # communication with computation
-        if self._overlap_communication:
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
+        params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)

-        with torch.cuda.stream(stream):
-            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
-
-            for param in params_in_bucket:
-                # the is_param_reduced flag should be False showing that
-                # this param is not reduced before calling self._reduce_grads_by_rank
-                is_param_reduced = self._param_store.is_param_reduced(param)
-
-                if is_param_reduced:
-                    msg = (
-                        f"Parameter of size ({param.size()}) has been reduced, "
-                        + "duplicate reduction will lead to arithmetic incorrectness"
-                    )
-                    raise RuntimeError(msg)
-
-                # update the flag
-                self._param_store.set_param_reduction_state(param, True)
+        for param in params_in_bucket:
+            # the is_param_reduced flag should be False showing that
+            # this param is not reduced before calling self._reduce_grads_by_rank
+            is_param_reduced = self._param_store.is_param_reduced(param)
+
+            if is_param_reduced:
+                msg = (
+                    f"Parameter of size ({param.size()}) has been reduced, "
+                    + "duplicate reduction will lead to arithmetic incorrectness"
+                )
+                raise RuntimeError(msg)
+
+            # update the flag
+            self._param_store.set_param_reduction_state(param, True)
+
+            if self._param_store.belongs_to_current_rank(param):
+                self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
+            else:
+                self._param_store.add_previous_reduced_param(param)

         self._bucket_store.reset_by_rank(reduce_rank)
@@ -385,9 +383,9 @@ class HybridZeroOptimizer(BaseOptimizer):

     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
         if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+            stream = self._comm_stream
+            stream.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
         else:
             stream = torch.cuda.current_stream()

@@ -421,6 +419,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         reduction_states = self._param_store.get_param_reduction_states()
         for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False
+        self._param_store.reset_reduced_data_for_compute_norm()

         # accumulate gradient
         avg_gradients = self._grad_store._averaged_gradients
@@ -469,6 +468,30 @@ class HybridZeroOptimizer(BaseOptimizer):

         # Gradients may not be fully synchronized here.

+    def _compute_norm_with_stage(
+        self,
+        group_id: int = 0,
+        last_bucket: bool = False,
+        last_stage: bool = False,
+        previous_norm=None,
+    ):
+        # compute norm for gradients that have been reduced
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+        if len(params) == 0:
+            grads = [self.padding_grad]
+            params = [self.padding_tensor]
+
+        if self._clip_grad_norm > 0:
+            # this norm is before scaling, it will be very large
+            norm = compute_norm(
+                gradients=grads,
+                parameters=params,
+                last_stage=last_stage,
+                previous_norm=previous_norm,
+            )
+
+        return norm
+
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -480,7 +503,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"

-        timer("sync_grad").start()
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_communication:

@@ -490,54 +512,49 @@ class HybridZeroOptimizer(BaseOptimizer):
                     self._store_and_try_reduce_grads_by_bucket(param)

         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket()
+        self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+
+        # compute norm for gradients in the before bucket
+        groups_norms = []
+        for group_id in range(self.num_param_groups):
+            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
         if self._overlap_communication:
-            torch.cuda.synchronize()
+            # grads in the last bucket is reduced
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()

+        # compute norm for gradients in the last bucket
+        total_norms = []
+        for group_id in range(self.num_param_groups):
+            total_norms.append(
+                self._compute_norm_with_stage(
+                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                )
+            )
+
+        timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

-        return self._step(closure=closure)
+        return self._step(closure=closure, norms=total_norms)

-    def _step(self, closure=None):
+    def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
-        found_inf = self._check_overflow()
+        found_inf = False
+        # if there is INF values in grades, compute_norm func would also returns -1
+        # thus, we try to avoid call _check_overflow here
+        # found_inf = self._check_overflow()
+        # Because you may encounter inf when computing norm

-        timer("cal_norm").start()
-        norm_groups = []
-        for group_id in range(self.num_param_groups):
-            # compute norm
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-                parameters = self._param_store.get_fp16_params_by_rank_group(
-                    group_id=group_id, rank=self._zero_local_rank
-                )
-            else:
-                # in order to prevent collection communication from hanging,
-                # we need to involve rank that are not assigned parameters in compute_norm(),
-                # so we give them a fp16 vector of 0 values.
-                gradients = [self.padding_grad]
-                parameters = [self.padding_tensor]
-
-            if self._clip_grad_norm > 0:
-                # this norm is before scaling, it will be very large
-                norm_group = compute_norm(
-                    gradients=gradients,
-                    parameters=parameters,
-                )
-                if norm_group == -1:
-                    timer("cal_norm").stop()
-                    found_inf = True
-                    break
-                norm_groups.append(norm_group)
+        if -1 in norms:
+            found_inf = True

         loss_scale = float(self.loss_scale.item())  # backup
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)
             # update loss scale if overflow occurs
             if found_inf:
@@ -550,7 +567,6 @@ class HybridZeroOptimizer(BaseOptimizer):

         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
-        global_norm = 0
         for group_id in range(self.num_param_groups):
             # compute norm
             # The following operations are performed only on the rank to which parameters are assigned.

@@ -559,13 +575,14 @@ class HybridZeroOptimizer(BaseOptimizer):

                 # create flat gradient for the flat fp32 params
                 gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-                flat_fp16_avg_grads = flatten(gradients)
+                with torch.no_grad():
+                    flat_fp16_avg_grads = flatten(gradients)
                 self._grad_store.reset_average_gradients_by_group(group_id)
-                del gradients  # release cuda memory
+                gradients = None  # release cuda memory

                 dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
                 flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-                del flat_fp16_avg_grads  # release cuda memory
+                flat_fp16_avg_grads = None  # release cuda memory

                 param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
                 assert (
@@ -578,15 +595,16 @@ class HybridZeroOptimizer(BaseOptimizer):

         # unscale and clip grads
         # get the global norm
+        global_norm_groups = []
         if self._clip_grad_norm > 0:
-            global_norm = sum(norm_groups) ** 0.5
+            for norm in norms:
+                global_norm_groups.append(norm**0.5)

         # the following operations are performed only on the rank to which parameters are assigned.
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)

-        timer("cal_norm").stop()
         # update the parameters
         timer("step").start()
@@ -611,7 +629,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         timer("step").stop()
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, global_norm / loss_scale
+        return True, [global_norm / loss_scale for global_norm in global_norm_groups]

     def broadcast_params(self, overlap=False):
         handles = []

@@ -655,18 +673,20 @@ class HybridZeroOptimizer(BaseOptimizer):

         return self._found_overflow.item() > 0

-    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
+    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
         # compute combined scale factor for this group
-        combined_scale = loss_scale
+        combined_scale_groups = []

         if self._clip_grad_norm > 0.0:
             # norm is in fact norm*scale
-            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
-            if clip > 1.0:
-                combined_scale = clip * loss_scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale

-        for grad in grad_groups_flat:
-            grad.data.mul_(1.0 / combined_scale)
+        for group_id, grad in enumerate(grad_groups_flat):
+            grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
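Taken together, the optimizer hunks above change `step()` so that per-group norms are computed in two stages (earlier buckets first, the last bucket folded in after its allreduce on `_comm_stream` finishes), overflow is detected from a -1 norm instead of a dedicated `_check_overflow` pass, and unscaling/clipping use one combined scale per parameter group. A condensed, single-process sketch of that control flow (the class is a stand-in with the same arithmetic, not the real optimizer):

    import math
    import torch

    class NormOverlapSketch:
        def __init__(self, clip_grad_norm, loss_scale):
            self.clip_grad_norm = clip_grad_norm
            self.loss_scale = loss_scale

        def step(self, former_sq_norms, last_sq_norms, flat_grads):
            # stage-1 norms were computed while the last bucket was still being reduced;
            # stage 2 folds in the last bucket, per parameter group
            if any(n == -1 for n in former_sq_norms + last_sq_norms):
                return False, None  # overflow: skip the update, let the loss scaler react

            global_norms = [math.sqrt(f + l) for f, l in zip(former_sq_norms, last_sq_norms)]
            for group_id, grad in enumerate(flat_grads):
                combined_scale = self.loss_scale
                clip = (global_norms[group_id] / self.loss_scale + 1e-6) / self.clip_grad_norm
                if clip > 1.0:  # same per-group clipping rule as _unscale_and_clip_grads
                    combined_scale = clip * self.loss_scale
                grad.mul_(1.0 / combined_scale)
            return True, [n / self.loss_scale for n in global_norms]

    opt = NormOverlapSketch(clip_grad_norm=1.0, loss_scale=1024.0)
    ok, norms = opt.step([4.0, 9.0], [5.0, 7.0], [torch.randn(10), torch.randn(10)])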
@@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
         self._is_param_reduced = dict()
         self._reduced_param = []

+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
         """
         Set the mapping between parameter to rank, each parameter should be owned by a rank.

@@ -223,6 +228,35 @@ class ParameterStore(BaseStore):
     def add_previous_reduced_param(self, tensor):
         self._reduced_param.append(tensor)

+    def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
+        group_id = getattr(param, "group_id")
+        if last_bucket:
+            if group_id not in self._last_bucket_reduced_param:
+                self._last_bucket_reduced_param[group_id] = []
+                self._last_bucket_reduced_grad[group_id] = []
+
+            self._last_bucket_reduced_param[group_id].append(param)
+            self._last_bucket_reduced_grad[group_id].append(param.grad)
+        else:
+            if group_id not in self._former_bucket_reduced_param:
+                self._former_bucket_reduced_param[group_id] = []
+                self._former_bucket_reduced_grad[group_id] = []
+
+            self._former_bucket_reduced_param[group_id].append(param)
+            self._former_bucket_reduced_grad[group_id].append(param.grad)
+
+    def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
+        if not last_bucket:
+            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+        else:
+            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+
+    def reset_reduced_data_for_compute_norm(self):
+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def clear_grads_of_previous_reduced_params(self):
         if len(self._reduced_param) > 0:
             for param in self._reduced_param:
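The new `ParameterStore` bookkeeping above simply groups reduced parameters (and their gradients) by parameter group and by whether they were reduced in an earlier bucket or in the last one, so `_compute_norm_with_stage` can fetch exactly the gradients that belong to each stage. A minimal stand-in showing the same shape of data structure:

    import torch

    class ReducedParamTrackerSketch:
        # group_id -> list of params, split into "earlier buckets" vs "last bucket"
        def __init__(self):
            self._former, self._last = {}, {}

        def add(self, param, group_id, last_bucket=False):
            target = self._last if last_bucket else self._former
            target.setdefault(group_id, []).append(param)

        def get(self, group_id, last_bucket=False):
            target = self._last if last_bucket else self._former
            params = target.get(group_id, [])
            return params, [p.grad for p in params]

        def reset(self):
            self._former, self._last = {}, {}

    tracker = ReducedParamTrackerSketch()
    p = torch.nn.Parameter(torch.randn(4))
    p.grad = torch.randn(4)
    tracker.add(p, group_id=0, last_bucket=False)
    params, grads = tracker.get(group_id=0)  # inputs for the stage-1 norm of group 0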
@@ -21,6 +21,7 @@ logger = get_logger(__file__)
 try:
     import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier

+    APEX_AVAILABLE = True
 except (ModuleNotFoundError, ImportError):
     logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")

@@ -162,6 +163,7 @@ def sync_param(flat_tensor, tensor_list):
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data


 def multi_tensor_l2norm_torch(tensor_list, per_tensor):
     # Convert tensor_list elements to torch.float32
     tensor_list = [tensor.float() for tensor in tensor_list]

@@ -175,6 +177,7 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):

     return l2_norm, per_tensor_norm


 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:

@@ -187,6 +190,7 @@ def calc_l2_norm(grads):
         norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm


 def calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:

@@ -195,7 +199,7 @@ def calc_lp(grads, norm_type):
     return norm


-def compute_norm(gradients, parameters, norm_type=2):
+def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
     """Get the norm
     Arguments:
         gradients (Iterable[Tensor]): The gradient value.

@@ -215,6 +219,13 @@ def compute_norm(gradients, parameters, norm_type=2):
     if norm_type == inf:
         total_norm = max(g.data.abs().max() for g in gradients)
         total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
+
+        if last_stage is False:
+            return total_norm_cuda
+
+        if previous_norm is not None:
+            total_norm_cuda = max(total_norm_cuda, previous_norm)
+
         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))

@@ -261,6 +272,12 @@ def compute_norm(gradients, parameters, norm_type=2):

         total_norm = tensor_parallel_norm

+        if last_stage is False:
+            return total_norm
+
+        if previous_norm is not None:
+            total_norm = total_norm + previous_norm
+
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
             dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
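The new `last_stage`/`previous_norm` arguments let `compute_norm` return a partial result for the earlier buckets and fold it into the last bucket's result later: for the inf-norm the two stages combine with `max`, for power norms they add before the cross-rank reduction. A single-process worked example of the L2 case (helper name is illustrative):

    import torch

    def staged_sq_l2(grads, last_stage=False, previous_norm=None):
        # sketch of the previous_norm/last_stage contract (single process, squared L2 only)
        total = sum(g.float().pow(2).sum() for g in grads)
        if last_stage is False:
            return total  # partial result, handed back in as previous_norm later
        if previous_norm is not None:
            total = total + previous_norm
        # the real compute_norm would dist.all_reduce(total) across model-parallel ranks here
        return total

    former = staged_sq_l2([torch.full((4,), 2.0)])                  # 4 * 2**2 = 16
    total = staged_sq_l2([torch.full((9,), 1.0)], last_stage=True,  # 9 * 1**2 = 9
                         previous_norm=former)
    assert total.item() ** 0.5 == 5.0                               # sqrt(16 + 9)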
train.py

@@ -7,6 +7,7 @@ import traceback
 from functools import partial
 from typing import Iterable

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch import nn

@@ -603,12 +604,12 @@ def main(args):
         trainer_result = trainer.step()
         assert trainer_result is not None

-        success_update, grad_norm = trainer_result
+        success_update, grad_norm_groups = trainer_result
         if success_update:  # update parameters successfully
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."

@@ -628,7 +629,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            grad_norm=grad_norm,
+            grad_norm=np.array(grad_norm_groups),
             metric=metric,
             update_panel=uniscale_logger is not None,
         )

@@ -668,7 +669,6 @@ if __name__ == "__main__":
         main(args)
     except Exception:
         logger.error(
-            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
-            exc_info=traceback.format_exc(),
+            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
         )
         mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())