refactor split_moe_group code

pull/182/head
Qu Wenwen 2023-09-15 16:55:16 +08:00
parent d13f5d3048
commit 5aa5c96ec8
1 changed files with 28 additions and 32 deletions

View File

@ -207,23 +207,23 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
group_moe = {}
# Create the param MoE groups, leave param assign to next step
for param_group in param_groups:
group_moe[param_group["name"]] = {}
for key in data_parallel_group_names:
group_moe[param_group["name"]][key] = {}
group_moe[param_group["name"]][key]["name"] = key
group_moe[param_group["name"]][key]["moe"] = True
group_moe[key] = {}
group_moe[key]["name"] = key
group_moe[key]["moe"] = True
for ori_key in param_group.keys():
if ori_key != "name":
if ori_key == "params":
group_moe[param_group["name"]][key][ori_key] = []
group_moe[key][ori_key] = []
else:
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
group_moe[key][ori_key] = param_group[ori_key]
# Assign param
for param_group in param_groups:
new_params = []
for param in param_group["params"]:
if is_moe_param(param):
group_moe[param_group["name"]][param.group_name]["params"].append(param)
group_moe[param.group_name]["params"].append(param)
# param_group['params'].remove(param)
else:
new_params.append(param)
@ -231,33 +231,29 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
# Flatten the moe groups
if max_group_size is not None:
for _, v in group_moe.items():
for _, v1 in v.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1["params"]:
if size_of_cur_group + param.numel() <= max_group_size:
cur_group.append(param)
size_of_cur_group += param.numel()
else:
all_groups.append(cur_group)
cur_group = [param]
size_of_cur_group = param.numel()
if cur_group:
for _, v1 in group_moe.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1["params"]:
cur_group.append(param)
size_of_cur_group += param.numel()
if size_of_cur_group > max_group_size:
all_groups.append(cur_group)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
cur_group = []
size_of_cur_group = 0
if cur_group:
all_groups.append(cur_group)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
else:
for _, v in group_moe.items():
for _, v1 in v.items():
param_groups.append(v1)
for _, v1 in group_moe.items():
param_groups.append(v1)
return tuple(param_groups)