mirror of https://github.com/InternLM/InternLM
restore logic for empty fp32 group
parent
e1ecfa51ec
commit
11a4a8bb44
|
@ -1,9 +1,7 @@
|
|||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context.parallel_context import ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
|
||||
|
||||
|
@ -75,12 +73,8 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
# bf16 param group, which is the first group in the param groups
|
||||
pgroup["params"] = origin_params
|
||||
|
||||
for _, g in new_groups.items():
|
||||
# remove empty group, especially for fp32 group
|
||||
is_empty = torch.tensor(bool(g["params"]), device=torch.cuda.current_device())
|
||||
dist.all_reduce(is_empty, group=gpc.get_group(ParallelMode.MODEL))
|
||||
if is_empty:
|
||||
param_groups.append(g)
|
||||
# param groups may contain empty groups, such as fp32
|
||||
param_groups.extend(new_groups.values())
|
||||
|
||||
return tuple(param_groups)
|
||||
|
||||
|
|
Loading…
Reference in New Issue