diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index b0741e4..9bead52 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -152,9 +152,9 @@ class NaiveAMPModel(nn.Module):
         Set module to fp32 and register automatic conversion hook in the forward pass.
         The fp32 modules are marked by set_fp32_attr_to_module(.)
         """
-        dtype = torch.float32
+        fp32_dtype = torch.float32
 
-        def to_dtype(x, dtype=dtype):
+        def to_dtype(x, dtype=fp32_dtype):
             if isinstance(x, Tensor) and x.dtype != dtype:
                 return x.to(dtype)
             return x
@@ -186,6 +186,6 @@ class NaiveAMPModel(nn.Module):
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
-                sub_module.to(dtype)
+                sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
index 0e19398..2f4aa67 100644
--- a/internlm/train/utils.py
+++ b/internlm/train/utils.py
@@ -38,7 +38,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     # create new groups for fp32, norm, moe gate and moe expert
     new_groups = {}
     new_groups["fp32"] = {"name": "fp32", "params": []}
-    if gpc.config.model.num_experts > 1:
+    if gpc.config.get("model_type") == "INTERNLM_MoE" and gpc.config.model.num_experts > 1:
         # norm and gate are special group to force sync (when enable MoE).
         for key in ["gate", "norm"]:
             new_groups[key] = {"name": key, key: True, "params": []}
@@ -57,7 +57,11 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # first split the norm and gate groups, which are special case to force sync (when enable MoE),
         # then fp32 group and the moe group.
         for param in pgroup["params"]:
-            if gpc.config.model.num_experts > 1 and is_norm_param(param):
+            if (
+                gpc.config.get("model_type") == "INTERNLM_MoE"
+                and gpc.config.model.num_experts > 1
+                and is_norm_param(param)
+            ):
                 new_groups["norm"]["params"].append(param)
             # gate param means MoE is enabled
             elif is_gate_param(param):
@@ -73,7 +77,9 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params
 
-    param_groups.extend(new_groups.values())
+    for _, g in new_groups.items():
+        if g["params"]:
+            param_groups.append(g)
 
     return tuple(param_groups)
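
The following is a minimal, self-contained sketch of the grouping behaviour this diff gives split_params_into_different_groups_for_optimizer, with a standalone make_param_groups helper standing in for the real gpc-config-driven function (the helper name and its model_type/num_experts arguments are illustrative assumptions, not part of the InternLM API). It shows the two changes above: MoE-specific groups are only created for INTERNLM_MoE models with more than one expert, and groups that end up empty are never passed on to the optimizer.

from typing import Dict, List, Tuple

import torch


def make_param_groups(
    params: List[torch.nn.Parameter],
    model_type: str = "INTERNLM",
    num_experts: int = 1,
) -> Tuple[Dict, ...]:
    """Split params into a default group plus optional fp32/gate/norm groups,
    keeping only the groups that actually received parameters."""
    new_groups = {"fp32": {"name": "fp32", "params": []}}
    # Mirrors the diff: MoE-only groups are created solely for MoE models.
    if model_type == "INTERNLM_MoE" and num_experts > 1:
        for key in ("gate", "norm"):
            new_groups[key] = {"name": key, key: True, "params": []}

    origin_params = []
    for p in params:
        if p.dtype == torch.float32:
            new_groups["fp32"]["params"].append(p)
        else:
            origin_params.append(p)

    param_groups = [{"name": "default", "params": origin_params}]
    # Mirrors the diff: skip groups that stayed empty so the optimizer never
    # sees a zero-parameter group (e.g. "gate"/"norm" on a dense model).
    for _, g in new_groups.items():
        if g["params"]:
            param_groups.append(g)
    return tuple(param_groups)


if __name__ == "__main__":
    params = [
        torch.nn.Parameter(torch.zeros(2, dtype=torch.bfloat16)),
        torch.nn.Parameter(torch.zeros(2, dtype=torch.float32)),
    ]
    # Prints "default 1" and "fp32 1"; the empty gate/norm groups are dropped.
    for group in make_param_groups(params, model_type="INTERNLM_MoE", num_experts=4):
        print(group["name"], len(group["params"]))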