diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 4e50549..211cb53 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -12,7 +12,6 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) Input Example: >>> ( >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, - >>> ..., >>> ) Returns: @@ -34,10 +33,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") - new_groups = [] + fp32_group = {"name": "fp32", "params": []} for pgroup in param_groups: - fp32_group = {"name": pgroup["name"] + "_fp32", "params": []} - # copy attribute from origin group + # copy attribute from origin group, we assume the input param_groups only + # have one group, so the attribute will not be copied multiple times. for ori_key in pgroup.keys(): if ori_key not in ("name", "params"): fp32_group[ori_key] = pgroup[ori_key] @@ -48,12 +47,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) fp32_group["params"].append(param) else: origin_params.append(param) - # origin group without fp32 + # bf16 param group, the first group in the param_groups pgroup["params"] = origin_params - new_groups.append(fp32_group) - for g in new_groups: - param_groups.append(g) + param_groups.append(fp32_group) return tuple(param_groups)