fix merge bugs

pull/182/head
Wenwen Qu 2023-09-27 11:17:03 +08:00
parent 8a63cb51ef
commit 2a70262ceb
2 changed files with 12 additions and 6 deletions


@@ -152,9 +152,9 @@ class NaiveAMPModel(nn.Module):
         Set module to fp32 and register automatic conversion hook in the forward pass.
         The fp32 modules are marked by set_fp32_attr_to_module(.)
         """
-        dtype = torch.float32
+        fp32_dtype = torch.float32
 
-        def to_dtype(x, dtype=dtype):
+        def to_dtype(x, dtype=fp32_dtype):
             if isinstance(x, Tensor) and x.dtype != dtype:
                 return x.to(dtype)
             return x
@@ -186,6 +186,6 @@ class NaiveAMPModel(nn.Module):
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
-                sub_module.to(dtype)
+                sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
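
For context, here is a minimal, self-contained sketch of the mechanism these hooks implement. It is not the InternLM code; `_cast`, `_pre_hook`, and `_post_hook` are illustrative names. The idea: a submodule marked as fp32 keeps fp32 weights inside an otherwise bf16 model, with a forward pre-hook casting its inputs up to fp32 and a forward hook casting its output back to bf16.

```python
import torch
from torch import nn, Tensor

# Minimal sketch (not the InternLM implementation) of running one submodule
# in fp32 inside a bf16 model via forward hooks.
fp32_dtype = torch.float32
bf16_dtype = torch.bfloat16


def _cast(x, dtype):
    # Only cast floating-point tensors; leave ints, bools and non-tensors alone.
    if isinstance(x, Tensor) and x.is_floating_point() and x.dtype != dtype:
        return x.to(dtype)
    return x


def _pre_hook(module, args):
    # forward pre-hook: cast positional inputs to fp32 before the module runs.
    return tuple(_cast(a, fp32_dtype) for a in args)


def _post_hook(module, args, output):
    # forward hook: cast the output back to bf16 for downstream bf16 layers.
    return _cast(output, bf16_dtype)


model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).to(bf16_dtype)
norm = model[1]
norm.to(fp32_dtype)  # keep the norm's parameters in fp32
norm.register_forward_pre_hook(_pre_hook)
norm.register_forward_hook(_post_hook)

out = model(torch.randn(2, 8, dtype=bf16_dtype))
print(out.dtype)  # torch.bfloat16, even though the norm itself ran in fp32
```

The rename to `fp32_dtype` in the hunk above serves the same purpose as the explicit constant here: the fp32 target dtype is given its own name so it cannot be confused with any other `dtype` in scope.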


@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.model.num_experts > 1:
+    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,7 +57,11 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # first split the norm and gate groups, which are special case to force sync (when enable MoE),
         # then fp32 group and the moe group.
         for param in pgroup["params"]:
-            if gpc.config.model.num_experts > 1 and is_norm_param(param):
+            if (
+                gpc.config.get("model_type") == "INTERNLM_MoE"
+                and gpc.config.model.num_experts > 1
+                and is_norm_param(param)
+            ):
                 new_groups["norm"]["params"].append(param)
             # gate param means MoE is enabled
             elif is_gate_param(param):
@@ -73,7 +77,9 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params
 
-    param_groups.extend(new_groups.values())
+    for _, g in new_groups.items():
+        if g["params"]:
+            param_groups.append(g)
 
     return tuple(param_groups)
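
The shape of this fix can be summarized with a small, hedged sketch (illustrative names only: `split_param_groups`, `use_moe`, `is_norm`; not the InternLM helper). It mirrors the two guards introduced above: MoE-specific groups are only created and filled when MoE is actually enabled, and at the end only non-empty groups are appended, as in the new `if g["params"]:` check.

```python
import torch

# Minimal sketch of splitting parameters into named groups and appending
# only the groups that actually received parameters.
def split_param_groups(params, use_moe=False, num_experts=1):
    new_groups = {"fp32": {"name": "fp32", "params": []}}
    if use_moe and num_experts > 1:  # guard mirrors the model_type + num_experts check
        new_groups["norm"] = {"name": "norm", "norm": True, "params": []}

    default_group = {"name": "default", "params": []}
    for p in params:
        if use_moe and num_experts > 1 and getattr(p, "is_norm", False):
            new_groups["norm"]["params"].append(p)
        elif p.dtype == torch.float32:
            new_groups["fp32"]["params"].append(p)
        else:
            default_group["params"].append(p)

    groups = [default_group]
    for g in new_groups.values():
        if g["params"]:  # skip empty groups instead of extending with all of them
            groups.append(g)
    return tuple(groups)


params = [
    torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16)),
    torch.nn.Parameter(torch.zeros(4, dtype=torch.float32)),
]
print([g["name"] for g in split_param_groups(params)])  # ['default', 'fp32']
```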