Refactor split_moe_group code: key the MoE param groups by data-parallel group name only, removing the per-param-group nesting

pull/182/head
Qu Wenwen 2023-09-15 16:55:16 +08:00
parent d13f5d3048
commit 5aa5c96ec8
1 changed file with 28 additions and 32 deletions

View File

@ -207,23 +207,23 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
group_moe = {} group_moe = {}
# Create the param MoE groups, leave param assign to next step # Create the param MoE groups, leave param assign to next step
for param_group in param_groups: for param_group in param_groups:
group_moe[param_group["name"]] = {}
for key in data_parallel_group_names: for key in data_parallel_group_names:
group_moe[param_group["name"]][key] = {} group_moe[key] = {}
group_moe[param_group["name"]][key]["name"] = key group_moe[key]["name"] = key
group_moe[param_group["name"]][key]["moe"] = True group_moe[key]["moe"] = True
for ori_key in param_group.keys(): for ori_key in param_group.keys():
if ori_key != "name": if ori_key != "name":
if ori_key == "params": if ori_key == "params":
group_moe[param_group["name"]][key][ori_key] = [] group_moe[key][ori_key] = []
else: else:
group_moe[param_group["name"]][key][ori_key] = param_group[ori_key] group_moe[key][ori_key] = param_group[ori_key]
# Assign param # Assign param
for param_group in param_groups: for param_group in param_groups:
new_params = [] new_params = []
for param in param_group["params"]: for param in param_group["params"]:
if is_moe_param(param): if is_moe_param(param):
group_moe[param_group["name"]][param.group_name]["params"].append(param) group_moe[param.group_name]["params"].append(param)
# param_group['params'].remove(param) # param_group['params'].remove(param)
else: else:
new_params.append(param) new_params.append(param)
@ -231,19 +231,17 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
# Flatten the moe groups # Flatten the moe groups
if max_group_size is not None: if max_group_size is not None:
for _, v in group_moe.items(): for _, v1 in group_moe.items():
for _, v1 in v.items():
cur_group = [] cur_group = []
all_groups = [] all_groups = []
size_of_cur_group = 0 size_of_cur_group = 0
for param in v1["params"]: for param in v1["params"]:
if size_of_cur_group + param.numel() <= max_group_size:
cur_group.append(param) cur_group.append(param)
size_of_cur_group += param.numel() size_of_cur_group += param.numel()
else: if size_of_cur_group > max_group_size:
all_groups.append(cur_group) all_groups.append(cur_group)
cur_group = [param] cur_group = []
size_of_cur_group = param.numel() size_of_cur_group = 0
if cur_group: if cur_group:
all_groups.append(cur_group) all_groups.append(cur_group)
for group in all_groups: for group in all_groups:
@ -254,10 +252,8 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
new_dict["params"] = group new_dict["params"] = group
param_groups.append(new_dict) param_groups.append(new_dict)
else: else:
for _, v in group_moe.items(): for _, v1 in group_moe.items():
for _, v1 in v.items():
param_groups.append(v1) param_groups.append(v1)
return tuple(param_groups) return tuple(param_groups)