From 4a478723820d14e43941719a8b20388aedc02b7e Mon Sep 17 00:00:00 2001
From: Qu Wenwen
Date: Tue, 19 Sep 2023 12:30:40 +0800
Subject: [PATCH] refactor code

---
 internlm/core/context/parallel_context.py |   5 +
 internlm/model/moe.py                     | 150 +-----------------
 internlm/model/utils.py                   |  18 +++
 .../solver/optimizer/hybrid_zero_optim.py |   2 +-
 internlm/train/training_internlm.py       |   4 +-
 internlm/train/utils.py                   |  73 +++++++++
 internlm/utils/parallel.py                |   2 +-
 7 files changed, 103 insertions(+), 151 deletions(-)
 create mode 100644 internlm/train/utils.py

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index ef1463c..d2f8412 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -149,11 +149,16 @@ class ParallelContext(metaclass=SingletonMeta):
         self.num_processes_on_current_node = -1
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
+        self._expert_parallel_group_names = []
 
     @property
     def config(self):
         return self._config
 
+    @property
+    def expert_parallel_group_names(self):
+        return self._expert_parallel_group_names
+
     def load_config(self, config: Union[dict, str]):
         """Loads the configuration from either a dict or a file.
 
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index e102df9..b116937 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -1,5 +1,4 @@
 import typing
-from typing import Dict, Tuple
 
 import torch
 
@@ -20,36 +19,6 @@ from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
 
 
-def has_moe_layers(m):
-    has_moe = False
-    num_experts = 0
-
-    for _, module in m.named_modules():
-        if isinstance(module, MoE):
-            has_moe = True
-            num_experts = module.num_experts
-            break
-    return has_moe, num_experts
-
-
-def is_moe_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_expert") and param.is_expert:
-        return True
-    return False
-
-
-def is_gate_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_gate") and param.is_gate:
-        return True
-    return False
-
-
-def is_norm_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "is_norm") and param.is_norm:
-        return True
-    return False
-
-
 class MoE(torch.nn.Module):
     """Initialize an MoE layer.
@@ -110,7 +79,9 @@ class MoE(torch.nn.Module):
         )
 
         # for elastic expert paralle, experts may have multiple groups
-        expert_group_name = f"ep_size_{self.ep_size}"
+        expert_group_name = f"moe_ep_size_{self.ep_size}"
+        if expert_group_name not in gpc.expert_parallel_group_names:
+            gpc.expert_parallel_group_names.append(expert_group_name)
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
@@ -188,118 +159,3 @@ class MoE(torch.nn.Module):
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
-
-
-def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
-    """Split parameters into different MoE groups for optimizer
-    Compatiable with muiltiple param groups, each should have a name
-
-    Args:
-        param_groups (Tuple[Dict]):
-            The list of parameter groups to split
-
-    Returns:
-        Tuple[Dict]:
-            list of MoE/non-MoE groups for optimizer
-    """
-    if isinstance(param_groups, tuple):
-        param_groups = list(param_groups)  # Tuple cannot be modified
-    elif isinstance(param_groups, dict):
-        param_groups = [param_groups]
-    elif not isinstance(param_groups, list):
-        raise ValueError(f"Unknown param group type of {type(param_groups)}")
-
-    # gather all data parallel group names
-    data_parallel_group_names = set()
-    for param_group in param_groups:
-        for param in param_group["params"]:
-            if is_moe_param(param):
-                data_parallel_group_names.add(param.group_name)
-    data_parallel_group_names = list(data_parallel_group_names)
-    group_moe = {}
-    gate_group = {}
-    norm_group = {}
-    # Create the param MoE groups, leave param assign to next step
-    for param_group in param_groups:
-        group_moe[param_group["name"]] = {}
-        for key in data_parallel_group_names:
-            group_moe[param_group["name"]][key] = {}
-            group_moe[param_group["name"]][key]["name"] = key
-            group_moe[param_group["name"]][key]["moe"] = True
-            for ori_key in param_group.keys():
-                if ori_key != "name":
-                    if ori_key == "params":
-                        group_moe[param_group["name"]][key][ori_key] = []
-                    else:
-                        group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
-        gate_group["name"] = "gate"
-        gate_group["gate"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    gate_group[ori_key] = []
-                else:
-                    gate_group[ori_key] = param_group[ori_key]
-        norm_group["name"] = "norm"
-        norm_group["norm"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    norm_group[ori_key] = []
-                else:
-                    norm_group[ori_key] = param_group[ori_key]
-    # Assign param
-    norm_params = []
-    gate_params = []
-    for param_group in param_groups:
-        new_params = []
-        for param in param_group["params"]:
-            if is_moe_param(param):
-                group_moe[param_group["name"]][param.group_name]["params"].append(param)
-            elif is_norm_param(param):
-                norm_params.append(param)
-            elif is_gate_param(param):
-                gate_params.append(param)
-            else:
-                new_params.append(param)
-        param_group["params"] = new_params
-    norm_group["params"] = norm_params
-    gate_group["params"] = gate_params
-    param_groups.append(norm_group)
-    param_groups.append(gate_group)
-
-    # Flatten the moe groups
-    if max_group_size is not None:
-        for _, v in group_moe.items():
-            for _, v1 in v.items():
-                cur_group = []
-                all_groups = []
-                size_of_cur_group = 0
-                for param in v1["params"]:
-                    if size_of_cur_group + param.numel() <= max_group_size:
-                        cur_group.append(param)
-                        size_of_cur_group += param.numel()
-                    else:
-                        all_groups.append(cur_group)
-                        cur_group = [param]
-                        size_of_cur_group = param.numel()
-                if cur_group:
-                    all_groups.append(cur_group)
-                for group in all_groups:
-                    new_dict = {}
-                    for key, val in v1.items():
-                        if key != "params":
-                            new_dict[key] = val
-                    new_dict["params"] = group
-                    param_groups.append(new_dict)
-    else:
-        for _, v in group_moe.items():
-            for _, v1 in v.items():
-                param_groups.append(v1)
-    return tuple(param_groups)
-
-
-def create_moe_param_groups(model, weight_decay):
-    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
-
-    return split_params_into_different_moe_groups_for_optimizer(parameters)
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 12f80e3..bb887c3 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -207,3 +207,21 @@ def try_import_RMSNorm():
         from internlm.model.norm import RMSNormTorch as RMSNorm
 
         return RMSNorm
+
+
+def is_moe_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_expert") and param.is_expert:
+        return True
+    return False
+
+
+def is_gate_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_gate") and param.is_gate:
+        return True
+    return False
+
+
+def is_norm_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "is_norm") and param.is_norm:
+        return True
+    return False
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 80aaa7b..3fe3eb5 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,7 +11,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.moe import is_moe_param
+from internlm.model.utils import is_moe_param
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index f3942af..30cc9e0 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -25,13 +25,13 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.moe import create_moe_param_groups
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
+from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -112,7 +112,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     adam_cfg = gpc.config.adam
     # split the moe parameters into different groups
     if gpc.config.model.num_experts > 1:
-        params = create_moe_param_groups(model, adam_cfg.weight_decay)
+        params = create_param_groups(model, adam_cfg.weight_decay)
     else:
         params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
     naive_optimizer = torch.optim.AdamW(
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
new file mode 100644
index 0000000..f91f1f1
--- /dev/null
+++ b/internlm/train/utils.py
@@ -0,0 +1,73 @@
+from typing import Dict, Tuple
+
+from internlm.core.context.parallel_context import global_context as gpc
+from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different MoE groups for the optimizer.
+    Compatible with multiple param groups; each should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split
+
+    Returns:
+        Tuple[Dict]:
+            list of MoE/non-MoE groups for optimizer
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    new_groups = {}
+    for pgroup in param_groups:
+        new_groups[pgroup["name"]] = {}
+
+        # create new groups for gate and norm
+        for key in ["gate", "norm"]:
+            new_groups[pgroup["name"]][key] = {}
+            new_groups[pgroup["name"]][key]["name"] = key
+            new_groups[pgroup["name"]][key][key] = True
+        # create moe groups
+        for key in gpc.expert_parallel_group_names:
+            new_groups[pgroup["name"]][key] = {}
+            new_groups[pgroup["name"]][key]["name"] = key
+            new_groups[pgroup["name"]][key]["moe"] = True
+
+        # copy attributes from the origin group
+        for ori_key in pgroup.keys():
+            for key in new_groups[pgroup["name"]].keys():
+                if ori_key != "name":
+                    if ori_key == "params":
+                        new_groups[pgroup["name"]][key][ori_key] = []
+                    else:
+                        new_groups[pgroup["name"]][key][ori_key] = pgroup[ori_key]
+        # assign params to their new groups
+        origin_params = []
+        for param in pgroup["params"]:
+            if is_moe_param(param):
+                new_groups[pgroup["name"]][param.group_name]["params"].append(param)
+            elif is_norm_param(param):
+                new_groups[pgroup["name"]]["norm"]["params"].append(param)
+            elif is_gate_param(param):
+                new_groups[pgroup["name"]]["gate"]["params"].append(param)
+            else:
+                origin_params.append(param)
+
+        pgroup["params"] = origin_params
+
+    for _, v in new_groups.items():
+        for _, v1 in v.items():
+            param_groups.append(v1)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+
+    return split_params_into_different_groups_for_optimizer(parameters)
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 5df51d1..b7e3b86 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -5,7 +5,7 @@ import torch.distributed as dist
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.moe import is_moe_param
+from internlm.model.utils import is_moe_param
 
 
 def is_model_parallel_parameter(p):
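
For reference, a minimal standalone sketch of the grouping scheme this patch introduces: parameters are bucketed into "default", "norm", "gate", and one "moe_ep_size_N" group per expert-parallel group, and the resulting dicts are handed to torch.optim.AdamW. This is not InternLM's own code path (which needs the initialized gpc context); the tag() helper, shapes, and hyperparameter values below are illustrative assumptions only.

import torch


def tag(param, **attrs):
    # attach marker attributes the same way InternLM flags its parameters
    # (is_expert / is_gate / is_norm / group_name)
    for name, value in attrs.items():
        setattr(param, name, value)
    return param


params = [
    torch.nn.Parameter(torch.randn(8, 8)),  # ordinary dense weight -> "default" group
    tag(torch.nn.Parameter(torch.randn(8, 8)), is_expert=True, group_name="moe_ep_size_2"),
    tag(torch.nn.Parameter(torch.randn(8)), is_norm=True),
    tag(torch.nn.Parameter(torch.randn(8, 2)), is_gate=True),
]

# one bucket per destination group: default, norm, gate, and one per expert-parallel group
buckets = {"default": [], "norm": [], "gate": [], "moe_ep_size_2": []}
for p in params:
    if getattr(p, "is_expert", False):
        buckets[p.group_name].append(p)
    elif getattr(p, "is_norm", False):
        buckets["norm"].append(p)
    elif getattr(p, "is_gate", False):
        buckets["gate"].append(p)
    else:
        buckets["default"].append(p)

# build optimizer param groups; extra keys ("name", "moe") are preserved by torch optimizers
param_groups = [
    {"name": name, "params": ps, "weight_decay": 0.01, "moe": name.startswith("moe_")}
    for name, ps in buckets.items()
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
print([(g["name"], len(g["params"]), g["moe"]) for g in optimizer.param_groups])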