mirror of https://github.com/InternLM/InternLM
refactor code for split group
parent
b7229fd9fb
commit
72bb3125a3
|
@ -12,7 +12,6 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
Input Example:
|
||||
>>> (
|
||||
>>> {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
|
||||
>>> ...,
|
||||
>>> )
|
||||
|
||||
Returns:
|
||||
|
@ -34,10 +33,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
elif not isinstance(param_groups, list):
|
||||
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
||||
|
||||
new_groups = []
|
||||
fp32_group = {"name": "fp32", "params": []}
|
||||
for pgroup in param_groups:
|
||||
fp32_group = {"name": pgroup["name"] + "_fp32", "params": []}
|
||||
# copy attribute from origin group
|
||||
# copy attribute from origin group, we assume the input param_groups only
|
||||
# have one group, so the attribute will not be copied multiple times.
|
||||
for ori_key in pgroup.keys():
|
||||
if ori_key not in ("name", "params"):
|
||||
fp32_group[ori_key] = pgroup[ori_key]
|
||||
|
@ -48,12 +47,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
fp32_group["params"].append(param)
|
||||
else:
|
||||
origin_params.append(param)
|
||||
# origin group without fp32
|
||||
# bf16 param group, the first group in the param_groups
|
||||
pgroup["params"] = origin_params
|
||||
new_groups.append(fp32_group)
|
||||
|
||||
for g in new_groups:
|
||||
param_groups.append(g)
|
||||
param_groups.append(fp32_group)
|
||||
|
||||
return tuple(param_groups)
|
||||
|
||||
|
|
Loading…
Reference in New Issue