From 72bb3125a31d968ad48d78a9c4d8f8e0bc1bd2b0 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Fri, 22 Sep 2023 14:59:47 +0800 Subject: [PATCH] refactor code for split group --- internlm/train/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 4e50549..211cb53 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -12,7 +12,6 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) Input Example: >>> ( >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, - >>> ..., >>> ) Returns: @@ -34,10 +33,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") - new_groups = [] + fp32_group = {"name": "fp32", "params": []} for pgroup in param_groups: - fp32_group = {"name": pgroup["name"] + "_fp32", "params": []} - # copy attribute from origin group + # copy attribute from origin group, we assume the input param_groups only + # have one group, so the attribute will not be copied multiple times. for ori_key in pgroup.keys(): if ori_key not in ("name", "params"): fp32_group[ori_key] = pgroup[ori_key] @@ -48,12 +47,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) fp32_group["params"].append(param) else: origin_params.append(param) - # origin group without fp32 + # bf16 param group, the first group in the param_groups pgroup["params"] = origin_params - new_groups.append(fp32_group) - for g in new_groups: - param_groups.append(g) + param_groups.append(fp32_group) return tuple(param_groups)