mirror of https://github.com/InternLM/InternLM
fix bugs with _compute_norm_with_moe_group
parent
6cf0fec314
commit
b10e5132fe
|
@ -537,6 +537,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
gradients=gradients,
|
||||
parameters=parameters,
|
||||
last_stage=True,
|
||||
is_moe_group=True,
|
||||
)
|
||||
|
||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||
|
|
|
@ -212,7 +212,7 @@ def calc_lp(grads, norm_type):
|
|||
return norm
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
|
||||
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
|
@ -305,7 +305,8 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
|
|||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
if not is_moe_group:
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
|
Loading…
Reference in New Issue