refactor code for split group

pull/319/head
Wenwen Qu 2023-09-22 14:59:47 +08:00
parent b7229fd9fb
commit 72bb3125a3
1 changed files with 5 additions and 8 deletions

View File

@ -12,7 +12,6 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
Input Example:
>>> (
>>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
>>> ...,
>>> )
Returns:
@ -34,10 +33,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
elif not isinstance(param_groups, list):
raise ValueError(f"Unknown param group type of {type(param_groups)}")
new_groups = []
fp32_group = {"name": "fp32", "params": []}
for pgroup in param_groups:
fp32_group = {"name": pgroup["name"] + "_fp32", "params": []}
# copy attribute from origin group
# copy attribute from origin group, we assume the input param_groups only
# have one group, so the attribute will not be copyed multiple times.
for ori_key in pgroup.keys():
if ori_key not in ("name", "params"):
fp32_group[ori_key] = pgroup[ori_key]
@ -48,12 +47,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
fp32_group["params"].append(param)
else:
origin_params.append(param)
# origin group without fp32
# bf16 param group, the first group in the param_groups
pgroup["params"] = origin_params
new_groups.append(fp32_group)
for g in new_groups:
param_groups.append(g)
param_groups.append(fp32_group)
return tuple(param_groups)