refactor split_moe_group code

pull/182/head
Qu Wenwen 2023-09-15 16:55:16 +08:00
parent d13f5d3048
commit 5aa5c96ec8
1 changed file with 28 additions and 32 deletions

View File

@ -207,23 +207,23 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
group_moe = {} group_moe = {}
# Create the param MoE groups, leave param assign to next step # Create the param MoE groups, leave param assign to next step
for param_group in param_groups: for param_group in param_groups:
group_moe[param_group["name"]] = {}
for key in data_parallel_group_names: for key in data_parallel_group_names:
group_moe[param_group["name"]][key] = {} group_moe[key] = {}
group_moe[param_group["name"]][key]["name"] = key group_moe[key]["name"] = key
group_moe[param_group["name"]][key]["moe"] = True group_moe[key]["moe"] = True
for ori_key in param_group.keys(): for ori_key in param_group.keys():
if ori_key != "name": if ori_key != "name":
if ori_key == "params": if ori_key == "params":
group_moe[param_group["name"]][key][ori_key] = [] group_moe[key][ori_key] = []
else: else:
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key] group_moe[key][ori_key] = param_group[ori_key]
# Assign param # Assign param
for param_group in param_groups: for param_group in param_groups:
new_params = [] new_params = []
for param in param_group["params"]: for param in param_group["params"]:
if is_moe_param(param): if is_moe_param(param):
group_moe[param_group["name"]][param.group_name]["params"].append(param) group_moe[param.group_name]["params"].append(param)
# param_group['params'].remove(param) # param_group['params'].remove(param)
else: else:
new_params.append(param) new_params.append(param)
@ -231,33 +231,29 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
# Flatten the moe groups # Flatten the moe groups
if max_group_size is not None: if max_group_size is not None:
for _, v in group_moe.items(): for _, v1 in group_moe.items():
for _, v1 in v.items(): cur_group = []
cur_group = [] all_groups = []
all_groups = [] size_of_cur_group = 0
size_of_cur_group = 0 for param in v1["params"]:
for param in v1["params"]: cur_group.append(param)
if size_of_cur_group + param.numel() <= max_group_size: size_of_cur_group += param.numel()
cur_group.append(param) if size_of_cur_group > max_group_size:
size_of_cur_group += param.numel()
else:
all_groups.append(cur_group)
cur_group = [param]
size_of_cur_group = param.numel()
if cur_group:
all_groups.append(cur_group) all_groups.append(cur_group)
for group in all_groups: cur_group = []
new_dict = {} size_of_cur_group = 0
for key, val in v1.items(): if cur_group:
if key != "params": all_groups.append(cur_group)
new_dict[key] = val for group in all_groups:
new_dict["params"] = group new_dict = {}
param_groups.append(new_dict) for key, val in v1.items():
if key != "params":
new_dict[key] = val
new_dict["params"] = group
param_groups.append(new_dict)
else: else:
for _, v in group_moe.items(): for _, v1 in group_moe.items():
for _, v1 in v.items(): param_groups.append(v1)
param_groups.append(v1)
return tuple(param_groups) return tuple(param_groups)