restore logic for empty fp32 group

pull/382/head
Qu Wenwen 2023-10-07 13:27:13 +08:00
parent e1ecfa51ec
commit 11a4a8bb44
1 changed files with 2 additions and 8 deletions

View File

@ -1,9 +1,7 @@
from typing import Dict, Tuple
import torch
import torch.distributed as dist
from internlm.core.context.parallel_context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
@ -75,12 +73,8 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
# bf16 param group, which is the first group in the param groups
pgroup["params"] = origin_params
for _, g in new_groups.items():
# remove empty group, especially for fp32 group
is_empty = torch.tensor(bool(g["params"]), device=torch.cuda.current_device())
dist.all_reduce(is_empty, group=gpc.get_group(ParallelMode.MODEL))
if is_empty:
param_groups.append(g)
# param groups may contain empty groups, such as fp32
param_groups.extend(new_groups.values())
return tuple(param_groups)