From 4df2c47472941d8426e55ce236f58a7727f3001d Mon Sep 17 00:00:00 2001
From: Qu Wenwen
Date: Fri, 27 Oct 2023 15:34:48 +0800
Subject: [PATCH] refactor code

---
 internlm/solver/optimizer/utils.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 2f7e21f..e0e2390 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -252,16 +252,6 @@ def reduce_grads(gradients, parameters, fine_grained=False):
     return parallel_grads
 
 
-def reduce_moe_norm(total_norm):
-    pg = gpc.get_group(ParallelMode.EXPERT)
-    scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-    dist.all_reduce(scaled_norm_tensor, group=pg)
-    total_norm = scaled_norm_tensor.item()
-
-    return total_norm
-
-
 def compute_norm(
     gradients,
     parameters,
@@ -344,7 +334,11 @@ def compute_norm(
     # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
     # model and zero have been reduced!!!
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_norm = reduce_moe_norm(total_norm)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        total_norm = scaled_norm_tensor.item()
 
     # Scale.
     if total_norm == float("inf") or total_norm == -float("inf"):
@@ -433,7 +427,12 @@ def compute_param_norm(
 
     # moe
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_param_norms = reduce_moe_norm(total_param_norms)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device())
+        scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
+        dist.all_reduce(scaled_param_norm, group=pg)
+        for i, param_name in enumerate(total_param_norms.keys()):
+            total_param_norms[param_name] = scaled_param_norm[i].item()
 
     # scale
     for param_name, param_norm in total_param_norms.items():
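
Note for readers outside the InternLM tree: below is a minimal standalone sketch (not part of the patch) of the averaged all-reduce that this commit inlines into compute_norm(). The names expert_group and data_world_size are hypothetical stand-ins for gpc.get_group(ParallelMode.EXPERT) and gpc.get_world_size(ParallelMode.DATA), and an initialized torch.distributed process group is assumed.

import torch
import torch.distributed as dist


def allreduce_avg_norm(total_norm: float, expert_group, data_world_size: int) -> float:
    # MoE parameters are not synchronized by the regular data-parallel
    # all-reduce, so their norm is reduced separately: pre-scale the local
    # norm by the data-parallel world size, then sum it over the expert
    # group, mirroring the logic the patch inlines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scaled = torch.tensor(total_norm / float(data_world_size), dtype=torch.float, device=device)
    dist.all_reduce(scaled, group=expert_group)  # defaults to ReduceOp.SUM
    return scaled.item()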