From 89d0373a8c2fa81baf92dd87ab6ccd323d72fbc4 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Fri, 22 Sep 2023 13:49:17 +0800 Subject: [PATCH] refactor code for split group --- internlm/train/utils.py | 46 +++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 256981a..4e50549 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -8,12 +8,24 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) Compatiable with muiltiple param groups, each should have a name Args: - param_groups (Tuple[Dict]): - The list of parameter groups to split + param_groups (Tuple[Dict]): The list of parameter groups to split + Input Example: + >>> ( + >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, + >>> ..., + >>> ) Returns: - Tuple[Dict]: - list of fp16/fp32 groups for optimizer + Tuple[Dict]: list of params groups for optimizer + Output Example: + >>> ( + >>> {'name': 'default','params': [tensor],'weight_decay' :xxx}, + >>> {'name': 'default_fp32', 'params': [tensor],'weight_decay' :xxx}, + >>> ..., + >>> ) + + Returns: + Tuple[Dict]: list of fp16/fp32 groups for optimizer """ if isinstance(param_groups, tuple): param_groups = list(param_groups) # Tuple cannot be modified @@ -22,22 +34,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") - new_groups = {} + new_groups = [] for pgroup in param_groups: - new_groups[pgroup["name"]] = {} - new_groups[pgroup["name"]]["fp32"] = {} - - # Create fp32 groups and copy origin attribute - fp32_group = new_groups[pgroup["name"]]["fp32"] - fp32_group["name"] = pgroup["name"] + "_fp32" - # copy attribute for fp32 group + fp32_group = {"name": pgroup["name"] + "_fp32", "params": []} + # copy attribute from origin group for ori_key in pgroup.keys(): - if ori_key != "name": - if ori_key == "params": - fp32_group[ori_key] = [] - else: - fp32_group[ori_key] = pgroup[ori_key] - + if ori_key not in ("name", "params"): + fp32_group[ori_key] = pgroup[ori_key] # Assign param origin_params = [] for param in pgroup["params"]: @@ -45,13 +48,12 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) fp32_group["params"].append(param) else: origin_params.append(param) - # origin group without fp32 pgroup["params"] = origin_params + new_groups.append(fp32_group) - for _, v in new_groups.items(): - for _, v1 in v.items(): - param_groups.append(v1) + for g in new_groups: + param_groups.append(g) return tuple(param_groups)