diff --git a/internlm/model/moe.py b/internlm/model/moe.py index 631b85f..6d02779 100644 --- a/internlm/model/moe.py +++ b/internlm/model/moe.py @@ -207,23 +207,23 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic group_moe = {} # Create the param MoE groups, leave param assign to next step for param_group in param_groups: - group_moe[param_group["name"]] = {} for key in data_parallel_group_names: - group_moe[param_group["name"]][key] = {} - group_moe[param_group["name"]][key]["name"] = key - group_moe[param_group["name"]][key]["moe"] = True + group_moe[key] = {} + group_moe[key]["name"] = key + group_moe[key]["moe"] = True for ori_key in param_group.keys(): if ori_key != "name": if ori_key == "params": - group_moe[param_group["name"]][key][ori_key] = [] + group_moe[key][ori_key] = [] else: - group_moe[param_group["name"]][key][ori_key] = param_group[ori_key] + group_moe[key][ori_key] = param_group[ori_key] + # Assign param for param_group in param_groups: new_params = [] for param in param_group["params"]: if is_moe_param(param): - group_moe[param_group["name"]][param.group_name]["params"].append(param) + group_moe[param.group_name]["params"].append(param) # param_group['params'].remove(param) else: new_params.append(param) @@ -231,33 +231,29 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic # Flatten the moe groups if max_group_size is not None: - for _, v in group_moe.items(): - for _, v1 in v.items(): - cur_group = [] - all_groups = [] - size_of_cur_group = 0 - for param in v1["params"]: - if size_of_cur_group + param.numel() <= max_group_size: - cur_group.append(param) - size_of_cur_group += param.numel() - else: - all_groups.append(cur_group) - cur_group = [param] - size_of_cur_group = param.numel() - if cur_group: + for _, v1 in group_moe.items(): + cur_group = [] + all_groups = [] + size_of_cur_group = 0 + for param in v1["params"]: + cur_group.append(param) + size_of_cur_group += param.numel() + if size_of_cur_group > max_group_size: all_groups.append(cur_group) - for group in all_groups: - new_dict = {} - for key, val in v1.items(): - if key != "params": - new_dict[key] = val - new_dict["params"] = group - param_groups.append(new_dict) + cur_group = [] + size_of_cur_group = 0 + if cur_group: + all_groups.append(cur_group) + for group in all_groups: + new_dict = {} + for key, val in v1.items(): + if key != "params": + new_dict[key] = val + new_dict["params"] = group + param_groups.append(new_dict) else: - for _, v in group_moe.items(): - for _, v1 in v.items(): - param_groups.append(v1) - + for _, v1 in group_moe.items(): + param_groups.append(v1) return tuple(param_groups)