mirror of https://github.com/InternLM/InternLM
refactor split_moe_group code
parent
d13f5d3048
commit
5aa5c96ec8
|
@ -207,23 +207,23 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
|
|||
group_moe = {}
|
||||
# Create the param MoE groups, leave param assign to next step
|
||||
for param_group in param_groups:
|
||||
group_moe[param_group["name"]] = {}
|
||||
for key in data_parallel_group_names:
|
||||
group_moe[param_group["name"]][key] = {}
|
||||
group_moe[param_group["name"]][key]["name"] = key
|
||||
group_moe[param_group["name"]][key]["moe"] = True
|
||||
group_moe[key] = {}
|
||||
group_moe[key]["name"] = key
|
||||
group_moe[key]["moe"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
group_moe[param_group["name"]][key][ori_key] = []
|
||||
group_moe[key][ori_key] = []
|
||||
else:
|
||||
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
|
||||
group_moe[key][ori_key] = param_group[ori_key]
|
||||
|
||||
# Assign param
|
||||
for param_group in param_groups:
|
||||
new_params = []
|
||||
for param in param_group["params"]:
|
||||
if is_moe_param(param):
|
||||
group_moe[param_group["name"]][param.group_name]["params"].append(param)
|
||||
group_moe[param.group_name]["params"].append(param)
|
||||
# param_group['params'].remove(param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
|
@ -231,33 +231,29 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
|
|||
|
||||
# Flatten the moe groups
|
||||
if max_group_size is not None:
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
cur_group = []
|
||||
all_groups = []
|
||||
size_of_cur_group = 0
|
||||
for param in v1["params"]:
|
||||
if size_of_cur_group + param.numel() <= max_group_size:
|
||||
cur_group.append(param)
|
||||
size_of_cur_group += param.numel()
|
||||
else:
|
||||
all_groups.append(cur_group)
|
||||
cur_group = [param]
|
||||
size_of_cur_group = param.numel()
|
||||
if cur_group:
|
||||
for _, v1 in group_moe.items():
|
||||
cur_group = []
|
||||
all_groups = []
|
||||
size_of_cur_group = 0
|
||||
for param in v1["params"]:
|
||||
cur_group.append(param)
|
||||
size_of_cur_group += param.numel()
|
||||
if size_of_cur_group > max_group_size:
|
||||
all_groups.append(cur_group)
|
||||
for group in all_groups:
|
||||
new_dict = {}
|
||||
for key, val in v1.items():
|
||||
if key != "params":
|
||||
new_dict[key] = val
|
||||
new_dict["params"] = group
|
||||
param_groups.append(new_dict)
|
||||
cur_group = []
|
||||
size_of_cur_group = 0
|
||||
if cur_group:
|
||||
all_groups.append(cur_group)
|
||||
for group in all_groups:
|
||||
new_dict = {}
|
||||
for key, val in v1.items():
|
||||
if key != "params":
|
||||
new_dict[key] = val
|
||||
new_dict["params"] = group
|
||||
param_groups.append(new_dict)
|
||||
else:
|
||||
for _, v in group_moe.items():
|
||||
for _, v1 in v.items():
|
||||
param_groups.append(v1)
|
||||
|
||||
for _, v1 in group_moe.items():
|
||||
param_groups.append(v1)
|
||||
return tuple(param_groups)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue