fix merge bugs

pull/182/head
Wenwen Qu 2023-09-27 11:17:03 +08:00
parent 8a63cb51ef
commit 2a70262ceb
2 changed files with 12 additions and 6 deletions

@@ -152,9 +152,9 @@ class NaiveAMPModel(nn.Module):
         Set module to fp32 and register automatic conversion hook in the forward pass.
         The fp32 modules are marked by set_fp32_attr_to_module(.)
         """
-        dtype = torch.float32
+        fp32_dtype = torch.float32
 
-        def to_dtype(x, dtype=dtype):
+        def to_dtype(x, dtype=fp32_dtype):
             if isinstance(x, Tensor) and x.dtype != dtype:
                 return x.to(dtype)
             return x
@@ -186,6 +186,6 @@ class NaiveAMPModel(nn.Module):
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
-                sub_module.to(dtype)
+                sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
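
The dtype -> fp32_dtype rename keeps the fp32 target distinct from any other dtype name in scope (presumably the clash the merge introduced), so sub_module.to(fp32_dtype) now casts the marked blocks to the precision the hook pair expects. For context, a minimal, self-contained sketch of the pattern this hunk touches follows: keep one marked submodule in fp32 inside a bf16 model by converting tensors at its boundary with forward hooks. The layer layout, hook names, and dtypes are illustrative assumptions, not InternLM's actual NaiveAMPModel code.

# Sketch only: mimics the fp32-island pattern above with made-up names.
from functools import partial

import torch
from torch import Tensor, nn


def _cast(x, dtype):
    # Cast only floating-point tensors; leave everything else untouched.
    if isinstance(x, Tensor) and x.is_floating_point() and x.dtype != dtype:
        return x.to(dtype)
    return x


def _pre_hook(module, args, dtype=torch.float32):
    # Convert inputs to fp32 before the marked module runs.
    return tuple(_cast(a, dtype) for a in args)


def _post_hook(module, args, output, dtype=torch.bfloat16):
    # Convert the output back to the model's working dtype afterwards.
    return _cast(output, dtype)


model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).to(torch.bfloat16)
fp32_norm = model[1].to(torch.float32)  # keep the norm block in fp32
fp32_norm.register_forward_pre_hook(partial(_pre_hook, dtype=torch.float32))
fp32_norm.register_forward_hook(partial(_post_hook, dtype=torch.bfloat16))

out = model(torch.randn(2, 8, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16, although the norm itself computed in fp32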

@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.model.num_experts > 1:
+    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,7 +57,11 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # first split the norm and gate groups, which are special case to force sync (when enable MoE),
         # then fp32 group and the moe group.
         for param in pgroup["params"]:
-            if gpc.config.model.num_experts > 1 and is_norm_param(param):
+            if (
+                gpc.config.get("model_type") == "INTERNLM_MoE"
+                and gpc.config.model.num_experts > 1
+                and is_norm_param(param)
+            ):
                 new_groups["norm"]["params"].append(param)
             # gate param means MoE is enabled
             elif is_gate_param(param):
@@ -73,7 +77,9 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params
 
-    param_groups.extend(new_groups.values())
+    for _, g in new_groups.items():
+        if g["params"]:
+            param_groups.append(g)
 
     return tuple(param_groups)
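
With this hunk the MoE-only groups are created only when model_type is INTERNLM_MoE, and only the new groups that actually received parameters are appended, so a non-MoE run no longer picks up empty optimizer groups (which downstream code presumably does not expect). Below is a minimal sketch of that split-then-drop-empty-groups pattern; the predicate, group names, and toy model are stand-ins, not InternLM's real split_params_into_different_groups_for_optimizer.

# Sketch only: illustrates the grouping pattern with a stand-in "norm" predicate.
from typing import Dict, List, Tuple

import torch
from torch import nn


def split_param_groups(param_groups: Tuple[Dict, ...], norm_param_ids: set) -> Tuple[Dict, ...]:
    # Move norm parameters into their own group, then append only non-empty groups.
    param_groups = list(param_groups)
    new_groups = {"norm": {"name": "norm", "params": []}}

    for pgroup in param_groups:
        remaining: List[nn.Parameter] = []
        for param in pgroup["params"]:
            if id(param) in norm_param_ids:  # stand-in for is_norm_param(param)
                new_groups["norm"]["params"].append(param)
            else:
                remaining.append(param)
        pgroup["params"] = remaining

    # Mirror of the fix: skip groups that stayed empty instead of extending blindly.
    for group in new_groups.values():
        if group["params"]:
            param_groups.append(group)
    return tuple(param_groups)


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
norm_ids = {id(p) for p in model[1].parameters()}
groups = split_param_groups(({"name": "default", "params": list(model.parameters())},), norm_ids)
print([(g["name"], len(g["params"])) for g in groups])  # [('default', 2), ('norm', 2)]
optimizer = torch.optim.AdamW(groups, lr=1e-4)  # optimizers tolerate the extra "name" key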