From 0ab3de89942beff733ecc5f3fae18e7af9451314 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 22 Aug 2023 14:00:07 +0800
Subject: [PATCH 1/2] fix bugs with compute moe norm

---
 .../solver/optimizer/hybrid_zero_optim.py | 25 +++++++++++++++----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 0d5ce4b..dfe0a4a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -522,6 +522,19 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return norm
 
+    def _compute_norm_with_moe_group(self, group_id):
+        parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
+        # we do not get the average grad for moe parameters, so we have to construct
+        # the gradients list here. Maybe this can be optimized.
+        gradients = [p.grad for p in parameters]
+        norm = compute_norm(
+            gradients=gradients,
+            parameters=parameters,
+            last_stage=True,
+        )
+
+        return norm
+
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -559,12 +572,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients in the last bucket
         total_norms = []
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+            if self._is_moe_group(self.optim.param_groups[group_id]):
+                total_norms.append(self._compute_norm_with_moe_group(group_id=group_id))
+            else:
+                total_norms.append(
+                    self._compute_norm_with_stage(
+                        group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    )
                 )
-            )
-
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
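For context before the second patch: _compute_norm_with_moe_group simply rebuilds the gradient list for the MoE parameter group and hands it to compute_norm. Below is a minimal standalone sketch of that idea in plain PyTorch, not InternLM code; the helper name grad_l2_norm and the toy model are illustrative only.

import torch

def grad_l2_norm(parameters):
    """Total L2 norm over the gradients of the given parameters."""
    # Build the gradient list explicitly, skipping parameters without grads.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    # Sum of squared per-tensor L2 norms, then a final square root.
    total_sq = sum(g.detach().float().norm(2) ** 2 for g in grads)
    return float(total_sq.sqrt())

# Toy usage: populate gradients with a backward pass, then compute the norm.
model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print(grad_l2_norm(model.parameters()))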
From 94b8b18a49ce6ce0755800a56d3c3b5088043835 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 22 Aug 2023 14:30:13 +0800
Subject: [PATCH 2/2] optimize code with moe norm computing

---
 .../solver/optimizer/hybrid_zero_optim.py | 24 +++++++------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index dfe0a4a..2528ef5 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -533,7 +533,14 @@ class HybridZeroOptimizer(BaseOptimizer):
             last_stage=True,
         )
 
-        return norm
+        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced
+        # during allreduce; the norms over the model parallel and zero groups have already been reduced.
+        pg = gpc.get_group(ParallelMode.DATA)
+        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        all_groups_norm = scaled_norm_tensor.item()
+        return all_groups_norm
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -586,19 +593,6 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return self._step(closure=closure, norms=total_norms)
 
-    def _get_norm_with_moe_layers(self, norm):
-        # all_groups_norm_old = all_groups_norm
-        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
-        pg = gpc.get_group(ParallelMode.DATA)
-        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-        scaled_norm_tensor = torch.tensor(
-            scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
-        )
-        dist.all_reduce(scaled_norm_tensor, group=pg)
-        all_groups_norm = scaled_norm_tensor.item()
-        # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
-        return all_groups_norm
-
     def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"
 
@@ -656,8 +650,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm_groups = []
         if self._clip_grad_norm > 0:
             for group_id in range(self.num_param_groups):
-                if self._is_moe_group(self.optim.param_groups[group_id]):
-                    self._get_norm_with_moe_layers(norms[group_id])
                 global_norm_groups.append(norms[group_id] ** 0.5)
 
         # the following operations are performed only on the rank to which parameters are assigned.
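The second patch folds the data parallel averaging into _compute_norm_with_moe_group itself and drops the now-unused _get_norm_with_moe_layers. Below is a minimal sketch of that averaging step using plain torch.distributed; it assumes an already-initialized process group, and the group argument plus the device choice stand in for gpc.get_group(ParallelMode.DATA) and get_current_device() in the real code.

import torch
import torch.distributed as dist

def allreduce_avg_norm(norm: float, group=None) -> float:
    """Average a scalar norm across the ranks of `group` (defaults to the world group)."""
    world_size = dist.get_world_size(group=group)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Pre-scale by 1/world_size so the default SUM all-reduce yields the average.
    scaled = torch.tensor(norm / float(world_size), device=device, dtype=torch.float)
    dist.all_reduce(scaled, group=group)
    return scaled.item()

One detail worth noting: NCCL only all-reduces CUDA tensors, which is why the patched code materializes the scalar on get_current_device() before the collective; with a gloo group a CPU tensor would also work.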