mirror of https://github.com/InternLM/InternLM
fix merge bugs
parent 8a63cb51ef
commit 2a70262ceb
@@ -152,9 +152,9 @@ class NaiveAMPModel(nn.Module):
         Set module to fp32 and register automatic conversion hook in the forward pass.
         The fp32 modules are marked by set_fp32_attr_to_module(.)
         """
-        dtype = torch.float32
+        fp32_dtype = torch.float32

-        def to_dtype(x, dtype=dtype):
+        def to_dtype(x, dtype=fp32_dtype):
             if isinstance(x, Tensor) and x.dtype != dtype:
                 return x.to(dtype)
             return x
@@ -186,6 +186,6 @@ class NaiveAMPModel(nn.Module):
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
-                sub_module.to(dtype)
+                sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
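For context on the two hunks above: NaiveAMPModel keeps selected sub-modules (transformer/embedding/norm blocks marked via set_fp32_attr_to_module) in fp32 while the rest of the model runs in low precision, by casting inputs up in a forward pre-hook and casting outputs back down in a forward hook. The sketch below only illustrates that pattern under assumed names; register_fp32_hooks, _cast, and the out_dtype default are illustrative, not the InternLM implementation, and the fp32_dtype constant mirrors the rename in the hunk.

import torch
from torch import Tensor, nn


def _cast(x, dtype):
    # Cast a tensor, or a tuple/list of tensors, to `dtype`; leave anything else untouched.
    if isinstance(x, Tensor) and x.is_floating_point() and x.dtype != dtype:
        return x.to(dtype)
    if isinstance(x, (tuple, list)):
        return type(x)(_cast(v, dtype) for v in x)
    return x


def register_fp32_hooks(sub_module: nn.Module, out_dtype: torch.dtype = torch.bfloat16):
    # Run this sub-module in fp32 while the surrounding model stays in `out_dtype`.
    fp32_dtype = torch.float32
    sub_module.to(fp32_dtype)

    def _pre_hook(module, args, dtype=fp32_dtype):
        # Forward pre-hook: cast positional inputs up to fp32 before forward().
        return _cast(args, dtype)

    def _post_hook(module, args, output, dtype=out_dtype):
        # Forward hook: cast outputs back down so downstream modules see `out_dtype`.
        return _cast(output, dtype)

    sub_module.register_forward_pre_hook(_pre_hook)
    sub_module.register_forward_hook(_post_hook)

Renaming the constant to fp32_dtype keeps it lexically distinct from the wrapper's own low-precision dtype, which appears to be the confusion this "fix merge bugs" commit resolves.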
@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.model.num_experts > 1:
+    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,7 +57,11 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # first split the norm and gate groups, which are special case to force sync (when enable MoE),
         # then fp32 group and the moe group.
         for param in pgroup["params"]:
-            if gpc.config.model.num_experts > 1 and is_norm_param(param):
+            if (
+                gpc.config.get("model_type") == "INTERNLM_MoE"
+                and gpc.config.model.num_experts > 1
+                and is_norm_param(param)
+            ):
                 new_groups["norm"]["params"].append(param)
             # gate param means MoE is enabled
             elif is_gate_param(param):
@@ -73,7 +77,9 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params

-    param_groups.extend(new_groups.values())
+    for _, g in new_groups.items():
+        if g["params"]:
+            param_groups.append(g)

     return tuple(param_groups)
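The optimizer-side hunks do two things: the MoE-only gate/norm groups are now created and populated only when model_type is INTERNLM_MoE with more than one expert, and only non-empty new groups are appended to param_groups, so the optimizer is never handed a parameter group with no parameters. Below is a minimal sketch of the same split-and-filter idea; the plain moe_enabled flag, the hypothetical is_norm_param / is_gate_param predicates, and the param.dtype == torch.float32 check stand in for the gpc config and the real helpers rather than reproducing them.

from typing import Dict, Tuple

import torch


def is_norm_param(param: torch.nn.Parameter) -> bool:
    # Hypothetical predicate; the real helper inspects attributes tagged at model build time.
    return getattr(param, "is_norm", False)


def is_gate_param(param: torch.nn.Parameter) -> bool:
    # Hypothetical predicate for MoE gate parameters.
    return getattr(param, "is_gate", False)


def split_param_groups(param_groups: Tuple[Dict, ...], moe_enabled: bool) -> Tuple[Dict, ...]:
    param_groups = list(param_groups)
    new_groups = {"fp32": {"name": "fp32", "params": []}}
    if moe_enabled:
        # gate and norm groups only exist when MoE is enabled.
        for key in ("gate", "norm"):
            new_groups[key] = {"name": key, key: True, "params": []}

    for pgroup in param_groups:
        origin_params = []
        for param in pgroup["params"]:
            if moe_enabled and is_norm_param(param):
                new_groups["norm"]["params"].append(param)
            elif moe_enabled and is_gate_param(param):
                new_groups["gate"]["params"].append(param)
            elif param.dtype == torch.float32:
                new_groups["fp32"]["params"].append(param)
            else:
                origin_params.append(param)
        # The original (low-precision) group keeps only what was not split out.
        pgroup["params"] = origin_params

    # Append only groups that actually received parameters, so no empty
    # param group ever reaches the optimizer.
    for group in new_groups.values():
        if group["params"]:
            param_groups.append(group)

    return tuple(param_groups)

For a dense (non-MoE) run this returns at most the original group plus an fp32 group, instead of also carrying empty gate/norm entries the way the previous param_groups.extend(new_groups.values()) did.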