From a5e80bcc569d04ff2cda36837fd4f6cbd6f0e60f Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Fri, 22 Sep 2023 15:24:28 +0800 Subject: [PATCH] refactor code for split param group --- internlm/train/utils.py | 59 +++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 8e4c1fa..8219e26 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -1,6 +1,9 @@ from typing import Dict, Tuple +import torch + from internlm.core.context.parallel_context import global_context as gpc +from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]: @@ -27,18 +30,6 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) >>> ) """ - def _get_group(param): - group_keys = ["is_expert", "is_gate", "is_norm"] - for i, key in enumerate(group_keys): - if hasattr(param, key) and getattr(param, key): - # experts param should return its group name - if i == 0: - return param.group_name - else: - return key[3:] - # TODO: deal with fp32 group - return None - if isinstance(param_groups, tuple): param_groups = list(param_groups) # Tuple cannot be modified elif isinstance(param_groups, dict): @@ -46,44 +37,42 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") - new_groups = [] + # create new groups for fp32, norm, moe gate and moe expert + new_groups = {} + new_groups["fp32"] = {"name": "fp32", "params": []} + for key in ["gate", "norm"]: + new_groups[key] = {"name": key, "sync_tp": True, "params": []} + for key in gpc.expert_parallel_group_names: + new_groups[key] = {"name": key, "moe": True, "params": []} + for pgroup in param_groups: - current_groups = {} - - # create new groups for gate and norm - for key in ["gate", "norm"]: - current_groups[key] = {"name": key, key: True, "params": []} - # create moe groups - for key in gpc.expert_parallel_group_names: - current_groups[key] = {"name": key, "moe": True, "params": []} - # copy attribute from origin group for ori_key in pgroup.keys(): if ori_key not in ("name", "params"): - for _, group in current_groups.items(): + for _, group in new_groups.items(): group[ori_key] = pgroup[ori_key] - - # Assign param + # assign param origin_params = [] + # first split the norm and gate groups, then the fp32 group, finally moe group for param in pgroup["params"]: - group = _get_group(param) - if group is not None: - current_groups[group]["params"].append(param) + if is_norm_param(param): + new_groups["norm"]["params"].append(param) + elif is_gate_param(param): + new_groups["gate"]["params"].append(param) + elif param.dtype == torch.float32: + new_groups["fp32"]["params"].append(param) + elif is_moe_param(param): + new_groups[param.group_name]["params"].append(param) else: origin_params.append(param) - + # bf16 param group, which is the first group in the param groups pgroup["params"] = origin_params - new_groups.append(current_groups) - - for g in new_groups: - for _, v in g.items(): - param_groups.append(v) + param_groups.extend(new_groups.values()) return tuple(param_groups) def create_param_groups(model, weight_decay): parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay} - return split_params_into_different_groups_for_optimizer(parameters)