diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index dfe0a4a..2528ef5 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -533,7 +533,14 @@ class HybridZeroOptimizer(BaseOptimizer): last_stage=True, ) - return norm + # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce + # model and zero have been reduced!!! + pg = gpc.get_group(ParallelMode.DATA) + scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + all_groups_norm = scaled_norm_tensor.item() + return all_groups_norm def step(self, closure=None): """Performs a single optimization step. @@ -586,19 +593,6 @@ class HybridZeroOptimizer(BaseOptimizer): return self._step(closure=closure, norms=total_norms) - def _get_norm_with_moe_layers(self, norm): - # all_groups_norm_old = all_groups_norm - # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce - pg = gpc.get_group(ParallelMode.DATA) - scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) - scaled_norm_tensor = torch.tensor( - scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float - ) - dist.all_reduce(scaled_norm_tensor, group=pg) - all_groups_norm = scaled_norm_tensor.item() - # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") - return all_groups_norm - def _step(self, closure=None, norms=None): assert closure is None, "closure is not supported by step()" @@ -656,8 +650,6 @@ class HybridZeroOptimizer(BaseOptimizer): global_norm_groups = [] if self._clip_grad_norm > 0: for group_id in range(self.num_param_groups): - if self._is_moe_group(self.optim.param_groups[group_id]): - self._get_norm_with_moe_layers(norms[group_id]) global_norm_groups.append(norms[group_id] ** 0.5) # the following operations are performed only on the rank to which parameters are assigned.