From 11a4a8bb446dfe42755f3ed364b1ede882c883dd Mon Sep 17 00:00:00 2001 From: Qu Wenwen Date: Sat, 7 Oct 2023 13:27:13 +0800 Subject: [PATCH] restore logic for empty fp32 group --- internlm/train/utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 14874f3..0e249fe 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -1,9 +1,7 @@ from typing import Dict, Tuple import torch -import torch.distributed as dist -from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param @@ -75,12 +73,8 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # bf16 param group, which is the first group in the param groups pgroup["params"] = origin_params - for _, g in new_groups.items(): - # remove empty group, especially for fp32 group - is_empty = torch.tensor(bool(g["params"]), device=torch.cuda.current_device()) - dist.all_reduce(is_empty, group=gpc.get_group(ParallelMode.MODEL)) - if is_empty: - param_groups.append(g) + # param groups may contain empty groups, such as fp32 + param_groups.extend(new_groups.values()) return tuple(param_groups)