fix bugs with _compute_norm_with_moe_group

pull/182/head
Wenwen Qu 2023-09-08 18:09:13 +08:00
parent 6cf0fec314
commit b10e5132fe
2 changed files with 4 additions and 2 deletions


@@ -537,6 +537,7 @@ class HybridZeroOptimizer(BaseOptimizer):
     gradients=gradients,
     parameters=parameters,
     last_stage=True,
+    is_moe_group=True,
 )
 # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce


@@ -212,7 +212,7 @@ def calc_lp(grads, norm_type):
     return norm


-def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
+def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
     """Get the norm
     Arguments:
         gradients (Iterable[Tensor]): The gradient value.
@@ -305,7 +305,8 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
     # This is because we use zero1, so we need to use this reduction.
     # TODO: Check zero group to be a subset of dp group.
-    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
+    if not is_moe_group:
+        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
     if torch.is_tensor(total_norm):
         total_norm = total_norm.item()
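Taken together, the two hunks make compute_norm skip the SUM reduction over the ZeRO-1 group when it is handling an MoE parameter group, because MoE params will not be synced during the regular allreduce; the caller then averages that norm across ranks. A minimal sketch of the patched reduction step, assuming an initialized torch.distributed process group; zero1_group here stands in for gpc.get_group(ParallelMode.ZERO1) in the real code.

import torch
import torch.distributed as dist


def reduce_total_norm(total_norm: torch.Tensor, zero1_group=None, is_moe_group: bool = False) -> float:
    # Sum the local norm over the ZeRO-1 group, except for MoE groups,
    # whose norm is averaged across ranks by the caller instead.
    if not is_moe_group:
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=zero1_group)
    if torch.is_tensor(total_norm):
        total_norm = total_norm.item()
    return total_norm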