refactor code for split group

pull/319/head
Wenwen Qu 2023-09-22 13:49:17 +08:00
parent 883160a558
commit 89d0373a8c
1 changed files with 24 additions and 22 deletions

View File

@ -8,12 +8,24 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
Compatiable with muiltiple param groups, each should have a name Compatiable with muiltiple param groups, each should have a name
Args: Args:
param_groups (Tuple[Dict]): param_groups (Tuple[Dict]): The list of parameter groups to split
The list of parameter groups to split Input Example:
>>> (
>>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
>>> ...,
>>> )
Returns: Returns:
Tuple[Dict]: Tuple[Dict]: list of params groups for optimizer
list of fp16/fp32 groups for optimizer Output Example:
>>> (
>>> {'name': 'default','params': [tensor],'weight_decay' :xxx},
>>> {'name': 'default_fp32', 'params': [tensor],'weight_decay' :xxx},
>>> ...,
>>> )
Returns:
Tuple[Dict]: list of fp16/fp32 groups for optimizer
""" """
if isinstance(param_groups, tuple): if isinstance(param_groups, tuple):
param_groups = list(param_groups) # Tuple cannot be modified param_groups = list(param_groups) # Tuple cannot be modified
@ -22,22 +34,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
elif not isinstance(param_groups, list): elif not isinstance(param_groups, list):
raise ValueError(f"Unknown param group type of {type(param_groups)}") raise ValueError(f"Unknown param group type of {type(param_groups)}")
new_groups = {} new_groups = []
for pgroup in param_groups: for pgroup in param_groups:
new_groups[pgroup["name"]] = {} fp32_group = {"name": pgroup["name"] + "_fp32", "params": []}
new_groups[pgroup["name"]]["fp32"] = {} # copy attribute from origin group
# Create fp32 groups and copy origin attribute
fp32_group = new_groups[pgroup["name"]]["fp32"]
fp32_group["name"] = pgroup["name"] + "_fp32"
# copy attribute for fp32 group
for ori_key in pgroup.keys(): for ori_key in pgroup.keys():
if ori_key != "name": if ori_key not in ("name", "params"):
if ori_key == "params": fp32_group[ori_key] = pgroup[ori_key]
fp32_group[ori_key] = []
else:
fp32_group[ori_key] = pgroup[ori_key]
# Assign param # Assign param
origin_params = [] origin_params = []
for param in pgroup["params"]: for param in pgroup["params"]:
@ -45,13 +48,12 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
fp32_group["params"].append(param) fp32_group["params"].append(param)
else: else:
origin_params.append(param) origin_params.append(param)
# origin group without fp32 # origin group without fp32
pgroup["params"] = origin_params pgroup["params"] = origin_params
new_groups.append(fp32_group)
for _, v in new_groups.items(): for g in new_groups:
for _, v1 in v.items(): param_groups.append(g)
param_groups.append(v1)
return tuple(param_groups) return tuple(param_groups)