diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 2e931ea..daab2bb 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -88,7 +88,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         scale_loss: int = 1,
-        moe_loss_coeff: float = 1.0,
+        moe_loss_coeff: float = 0.01,
     ):
         """Trains one batch of data.
 
@@ -136,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         return_output_label: bool = True,
-        moe_loss_coeff: float = 1.0,
+        moe_loss_coeff: float = 0.01,
     ):
         """The process function that loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 04a39bd..75beb14 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -1,4 +1,5 @@
 import typing
+from typing import Dict, Tuple
 
 import torch
 
@@ -31,7 +32,7 @@ def has_moe_layers(m):
 
 
 def is_moe_param(param: torch.Tensor) -> bool:
-    if hasattr(param, "allreduce") and not param.allreduce:
+    if hasattr(param, "all_reduce") and not param.all_reduce:
         return True
     return False
 
@@ -95,7 +96,8 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )
 
-        experts = Experts(experts, self.num_local_experts)
+        expert_group_name = f"ep_size_{self.ep_size}"
+        experts = Experts(experts, self.num_local_experts, expert_group_name)
 
         if using_default_moe:
             self.moe_layer = MOELayer(
@@ -148,3 +150,94 @@ class MoE(torch.nn.Module):
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
+
+
+def split_params_into_different_moe_groups_for_optimizer(
+    param_groups: Tuple[Dict], max_group_size=178956971
+) -> Tuple[Dict]:
+    """Split parameters into different MoE groups for the optimizer.
+    Compatible with multiple param groups; each should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split
+
+    Returns:
+        Tuple[Dict]:
+            list of MoE/non-MoE groups for optimizer
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    # gather all data parallel group names
+    data_parallel_group_names = set()
+    for param_group in param_groups:
+        for param in param_group["params"]:
+            if is_moe_param(param):
+                data_parallel_group_names.add(param.group_name)
+    data_parallel_group_names = list(data_parallel_group_names)
+    group_moe = {}
+    # Create the param MoE groups; leave param assignment to the next step
+    for param_group in param_groups:
+        group_moe[param_group["name"]] = {}
+        for key in data_parallel_group_names:
+            group_moe[param_group["name"]][key] = {}
+            group_moe[param_group["name"]][key]["name"] = key
+            group_moe[param_group["name"]][key]["moe"] = True
+            for ori_key in param_group.keys():
+                if ori_key != "name":
+                    if ori_key == "params":
+                        group_moe[param_group["name"]][key][ori_key] = []
+                    else:
+                        group_moe[param_group["name"]][key][ori_key] = param_group[ori_key]
+    # Assign params
+    for param_group in param_groups:
+        new_params = []
+        for param in param_group["params"]:
+            if is_moe_param(param):
+                group_moe[param_group["name"]][param.group_name]["params"].append(param)
+                # param_group['params'].remove(param)
+            else:
+                new_params.append(param)
+        param_group["params"] = new_params
+
+    # Flatten the moe groups
+    if max_group_size is not None:
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1["params"]:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
+                    all_groups.append(cur_group)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != "params":
+                            new_dict[key] = val
+                    new_dict["params"] = group
+                    param_groups.append(new_dict)
+    else:
+        for _, v in group_moe.items():
+            for _, v1 in v.items():
+                param_groups.append(v1)
+
+    return tuple(param_groups)
+
+
+def create_moe_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+
+    return split_params_into_different_moe_groups_for_optimizer(parameters)
diff --git a/internlm/moe/experts.py b/internlm/moe/experts.py
index bf34666..15e5289 100644
--- a/internlm/moe/experts.py
+++ b/internlm/moe/experts.py
@@ -21,7 +21,7 @@ class Experts(torch.nn.Module):
         Local Experts.
     """
 
-    def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
+    def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
         super().__init__()
 
         # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
@@ -38,6 +38,7 @@ class Experts(torch.nn.Module):
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
                 param.all_reduce = False
+                param.group_name = expert_group_name
 
     def forward(self, inputs):
         chunks = inputs.chunk(self.num_local_experts, dim=1)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index dd3476a..d59316c 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -166,6 +166,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
+            if "moe" in param_group.keys() and param_group["moe"]:
+                # MoE parameter groups are not ZeRO-partitioned here; skip them
+                continue
+
             group_params = param_group["params"]
 
             # add the fp16 params to fp16_param_groups for bookkeeping
@@ -512,7 +516,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             # all_groups_norm_old = all_groups_norm
             # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
             pg = gpc.get_group(ParallelMode.DATA)
-            print(type(norm_groups))
             scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
             scaled_norm_tensor = torch.tensor(
                 scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
diff --git a/train.py b/train.py
index 5de592b..39fa942 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,7 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
-from internlm.model.moe import has_moe_layers
+from internlm.model.moe import create_moe_param_groups, has_moe_layers
 from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -300,9 +300,14 @@ def initialize_optimizer(model: nn.Module):
     Returns:
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
    """
+    adam_cfg = gpc.config.adam
+    if gpc.config.model.num_experts > 1:
+        params = create_moe_param_groups(model, adam_cfg.weight_decay)
+    else:
+        params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
     naive_optimizer = torch.optim.AdamW(
-        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
+        params=params,
         lr=adam_cfg.lr,
         betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
         eps=adam_cfg.adam_eps,
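
Below is a minimal usage sketch of the new parameter-grouping path (not part of the patch itself). It assumes expert parameters have already been tagged the way Experts.__init__ now tags them; ToyMoEModel and the hard-coded group name "ep_size_2" are hypothetical stand-ins for a real MoE model.

```python
# Hypothetical sketch: how parameters tagged by Experts.__init__ flow through
# create_moe_param_groups() into the optimizer.
import torch
from torch import nn

from internlm.model.moe import create_moe_param_groups  # added in this patch


class ToyMoEModel(nn.Module):
    """Stand-in for a real MoE model; only the attribute tagging matters here."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)   # ordinary (non-expert) parameters
        self.expert = nn.Linear(8, 8)  # pretend this is a local expert
        # Mirror what Experts.__init__ now does for expert parameters.
        for _, param in self.expert.named_parameters():
            param.all_reduce = False
            param.group_name = "ep_size_2"  # hypothetical expert group name


model = ToyMoEModel()
param_groups = create_moe_param_groups(model, weight_decay=0.01)
# Result: one non-MoE group named "default" plus one group per expert
# group_name, each MoE group carrying "moe": True.
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
```

The extra "name" and "moe" keys are preserved by torch.optim param groups, which is what HybridZeroOptimizer relies on to detect and skip MoE groups during ZeRO partitioning.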