mirror of https://github.com/InternLM/InternLM
refactor code for split group
parent
883160a558
commit
89d0373a8c
|
@ -8,12 +8,24 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
||||||
Compatible with multiple param groups, each should have a name
|
Compatible with multiple param groups, each should have a name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
param_groups (Tuple[Dict]):
|
param_groups (Tuple[Dict]): The list of parameter groups to split
|
||||||
The list of parameter groups to split
|
Input Example:
|
||||||
|
>>> (
|
||||||
|
>>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
|
||||||
|
>>> ...,
|
||||||
|
>>> )
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict]:
|
Tuple[Dict]: list of params groups for optimizer
|
||||||
list of fp16/fp32 groups for optimizer
|
Output Example:
|
||||||
|
>>> (
|
||||||
|
>>> {'name': 'default','params': [tensor],'weight_decay' :xxx},
|
||||||
|
>>> {'name': 'default_fp32', 'params': [tensor],'weight_decay' :xxx},
|
||||||
|
>>> ...,
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict]: list of fp16/fp32 groups for optimizer
|
||||||
"""
|
"""
|
||||||
if isinstance(param_groups, tuple):
|
if isinstance(param_groups, tuple):
|
||||||
param_groups = list(param_groups) # Tuple cannot be modified
|
param_groups = list(param_groups) # Tuple cannot be modified
|
||||||
|
@ -22,22 +34,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
||||||
elif not isinstance(param_groups, list):
|
elif not isinstance(param_groups, list):
|
||||||
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
||||||
|
|
||||||
new_groups = {}
|
new_groups = []
|
||||||
for pgroup in param_groups:
|
for pgroup in param_groups:
|
||||||
new_groups[pgroup["name"]] = {}
|
fp32_group = {"name": pgroup["name"] + "_fp32", "params": []}
|
||||||
new_groups[pgroup["name"]]["fp32"] = {}
|
# copy attribute from origin group
|
||||||
|
|
||||||
# Create fp32 groups and copy origin attribute
|
|
||||||
fp32_group = new_groups[pgroup["name"]]["fp32"]
|
|
||||||
fp32_group["name"] = pgroup["name"] + "_fp32"
|
|
||||||
# copy attribute for fp32 group
|
|
||||||
for ori_key in pgroup.keys():
|
for ori_key in pgroup.keys():
|
||||||
if ori_key != "name":
|
if ori_key not in ("name", "params"):
|
||||||
if ori_key == "params":
|
fp32_group[ori_key] = pgroup[ori_key]
|
||||||
fp32_group[ori_key] = []
|
|
||||||
else:
|
|
||||||
fp32_group[ori_key] = pgroup[ori_key]
|
|
||||||
|
|
||||||
# Assign param
|
# Assign param
|
||||||
origin_params = []
|
origin_params = []
|
||||||
for param in pgroup["params"]:
|
for param in pgroup["params"]:
|
||||||
|
@ -45,13 +48,12 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
||||||
fp32_group["params"].append(param)
|
fp32_group["params"].append(param)
|
||||||
else:
|
else:
|
||||||
origin_params.append(param)
|
origin_params.append(param)
|
||||||
|
|
||||||
# origin group without fp32
|
# origin group without fp32
|
||||||
pgroup["params"] = origin_params
|
pgroup["params"] = origin_params
|
||||||
|
new_groups.append(fp32_group)
|
||||||
|
|
||||||
for _, v in new_groups.items():
|
for g in new_groups:
|
||||||
for _, v1 in v.items():
|
param_groups.append(g)
|
||||||
param_groups.append(v1)
|
|
||||||
|
|
||||||
return tuple(param_groups)
|
return tuple(param_groups)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue