From b10e5132fe2bf54883b250c91c806f2a76a32588 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 8 Sep 2023 18:09:13 +0800
Subject: [PATCH] fix bugs with _compute_norm_with_moe_group

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 1 +
 internlm/solver/optimizer/utils.py             | 5 +++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 3e43b3a..533904f 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -537,6 +537,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             gradients=gradients,
             parameters=parameters,
             last_stage=True,
+            is_moe_group=True,
         )
 
         # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 38e4560..f3a987c 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -212,7 +212,7 @@ def calc_lp(grads, norm_type):
     return norm
 
 
-def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
+def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
     """Get the norm
     Arguments:
         gradients (Iterable[Tensor]): The gradient value.
@@ -305,7 +305,8 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
 
         # This is because we use zero1, so we need to use this reduction.
         # TODO: Check zero group to be a subset of dp group.
-        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
+        if not is_moe_group:
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
 
     if torch.is_tensor(total_norm):
         total_norm = total_norm.item()
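
Note (reviewer sketch, not part of the patch): the new is_moe_group flag makes
compute_norm skip its ZERO1 all-reduce for the MoE parameter group, because the
caller in _compute_norm_with_moe_group is expected to average the norm across
ranks itself afterwards (see the "allreduce(avg)" comment in the first hunk).
A minimal sketch of that calling pattern is below; the moe_group handle and the
explicit averaging step are illustrative assumptions, not code from this patch.

    import torch
    import torch.distributed as dist

    # Hypothetical caller-side pattern after this patch:
    # 1) compute the MoE-group norm without the ZERO1 reduction,
    # 2) average it across the MoE process group explicitly, since MoE grads
    #    are not synchronized by the regular gradient all-reduce.
    total_norm = compute_norm(
        gradients=gradients,
        parameters=parameters,
        last_stage=True,
        is_moe_group=True,  # skips dist.all_reduce over ParallelMode.ZERO1
    )
    norm_tensor = torch.tensor(float(total_norm), device="cuda")
    dist.all_reduce(norm_tensor, op=dist.ReduceOp.SUM, group=moe_group)  # moe_group: hypothetical handle
    total_norm = norm_tensor.item() / dist.get_world_size(group=moe_group)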