diff --git a/configs/7B_MoE8_sft.py b/configs/7B_MoE8_sft.py new file mode 100644 index 0000000..b4b6b6c --- /dev/null +++ b/configs/7B_MoE8_sft.py @@ -0,0 +1,170 @@ +JOB_NAME = "7b_moe_train" +DO_ALERT = False + +SEQ_LEN = 2048 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 32 +VOCAB_SIZE = 103168 + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = "/path/to/dataset" +VALID_FOLDER = "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=2, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=50, + pack_sample_into_one=False, + total_steps=50000, + skip_batches="", + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + # train_folder=TRAIN_FOLDER, + # valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=10, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=True, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + 
eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +model = dict( + checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1] + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + num_experts=8, + moe_use_residual=False, + moe_gate_k=2, +) +""" +zero1 parallel: + 1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. + 3. if zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +tensor parallel: tensor parallel size, usually the number of GPUs per node. +""" +parallel = dict( + zero1=-1, + tensor=8, + pipeline=dict(size=1, interleaved_overlap=True), + sequence_parallel=False, +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + ), +) + +model_type = "INTERNLM_MoE" diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 651817f..3de3d45 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -107,7 +107,10 @@ class NonPipelineScheduler(BaseScheduler): with conditional_context(torch.no_grad(), enable=forward_only): self._call_hooks("before_forward", data) # moe_losses contains the loss of each layer - output, moe_losses = self._call_engine(engine, data) + if gpc.config.get("model_type") == "INTERNLM": + output = self._call_engine(engine, data) + if gpc.config.get("model_type") == "INTERNLM_MoE": + output, moe_losses = self._call_engine(engine, data) self._call_hooks("after_forward", output) self._call_hooks("post_helper_func", output, label) @@ -116,7 +119,11 @@ class NonPipelineScheduler(BaseScheduler): self._call_hooks("before_criterion", output, label) loss = self._call_engine_criterion(engine, output, label) self._call_hooks("after_criterion", loss) - moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff + moe_loss = ( + sum(moe_losses) * gpc.config.loss.moe_loss_coeff + if gpc.config.get("model_type") == "INTERNLM_MoE" + else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype")) + ) moe_loss /= scale_loss loss /= scale_loss loss += moe_loss @@ -199,4 +206,8 @@ class NonPipelineScheduler(BaseScheduler): if not return_output_label: outputs, labels = None, None - return
outputs, labels, loss, moe_loss + # Compatible for old code + if gpc.config.get("model_type") == "INTERNLM": + return outputs, labels, loss + if gpc.config.get("model_type") == "INTERNLM_MoE": + return outputs, labels, loss, moe_loss diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 6d18e02..42e58e9 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -276,7 +276,10 @@ class PipelineScheduler(BaseScheduler): self._call_hooks("before_forward", data) # moe_losses contains the loss of each layer in current stage - output_obj, moe_losses = self._call_engine(engine.model, data) + if gpc.config.get("model_type") == "INTERNLM": + output_obj = self._call_engine(engine.model, data) + if gpc.config.get("model_type") == "INTERNLM_MoE": + output_obj, moe_losses = self._call_engine(engine.model, data) self._call_hooks("after_forward", output_obj) if gpc.is_last_rank(ParallelMode.PIPELINE): @@ -292,7 +295,11 @@ class PipelineScheduler(BaseScheduler): accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff + moe_loss = ( + sum(moe_losses) * gpc.config.loss.moe_loss_coeff + if gpc.config.get("model_type") == "INTERNLM_MoE" + else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype")) + ) moe_loss /= self.num_microbatches accum_moe_loss.add_(moe_loss.detach()) @@ -658,9 +665,19 @@ class PipelineScheduler(BaseScheduler): self.load_batch(engine, data_iter) if forward_only: - return self._forward_only_step(engine, return_loss, return_output_label) + output, label, accum_loss, accum_moe_loss = self._forward_only_step( + engine, return_loss, return_output_label + ) else: - return self._forward_backward_step(engine, return_loss, return_output_label) + output, label, accum_loss, accum_moe_loss = self._forward_backward_step( + engine, return_loss, return_output_label + ) + + # Compatible for old code + if gpc.config.get("model_type") == "INTERNLM": + return output, label, accum_loss + if gpc.config.get("model_type") == "INTERNLM_MoE": + return output, label, accum_loss, accum_moe_loss class InterleavedPipelineScheduler(PipelineScheduler): @@ -799,7 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data) + if gpc.config.get("model_type") == "INTERNLM": + output_obj = self._call_engine(engine.model[chunk_id], data) + if gpc.config.get("model_type") == "INTERNLM_MoE": + output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data) # Convert output_obj to fp32 when last model chunk of last stage if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel): output_obj = engine.model[chunk_id].convert_to_fp32(output_obj) @@ -819,7 +839,11 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff + moe_loss = ( + sum(moe_losses) * gpc.config.loss.moe_loss_coeff + if gpc.config.get("model_type") == "INTERNLM_MoE" + else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype")) + ) moe_loss /= self.num_microbatches if self._accum_moe_loss is not None: @@ -1354,4 +1378,8 @@ 
class InterleavedPipelineScheduler(PipelineScheduler): self._clear_state() - return output, label, accum_loss, accum_moe_loss + # Compatible for old code + if gpc.config.get("model_type") == "INTERNLM": + return output, label, accum_loss + if gpc.config.get("model_type") == "INTERNLM_MoE": + return output, label, accum_loss, accum_moe_loss diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 705b17a..19be672 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -205,5 +205,4 @@ class Trainer: Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss). """ - output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) - return output, label, loss, moe_loss + return self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e7de61e..3896ede 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -253,8 +253,14 @@ def args_sanity_check(): # process the model config if "use_flash_attn" not in gpc.config.model: gpc.config.model._add_item("use_flash_attn", True) - if "num_experts" not in model: - model._add_item("num_experts", 1) + + if gpc.config.get("model_type") == "INTERNLM_MoE": + if "num_experts" not in model: + model._add_item("num_experts", 1) + if "moe_use_residual" not in model: + model._add_item("moe_use_residual", False) + if "moe_gate_k" not in model: + model._add_item("moe_gate_k", 2) # process the parallel config if "sequence_parallel" not in gpc.config.parallel: diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index b0fe77d..1bf7a86 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear from .metrics import AccPerplex from .modeling_internlm import build_model_with_cfg +from .modeling_moe import build_model_with_moe_cfg from .multi_head_attention import MHA from .utils import gather_forward_split_backward @@ -18,4 +19,5 @@ __all__ = [ "MHA", "gather_forward_split_backward", "build_model_with_cfg", + "build_model_with_moe_cfg", ] diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index d726fbd..2856a78 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -18,7 +18,6 @@ from internlm.model.linear import ( RewardModelLinear, ScaleColumnParallelLinear, ) -from internlm.model.moe import MoE from internlm.model.multi_head_attention import MHA from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm from internlm.solver.pipeline_utils import partition_uniform @@ -51,17 +50,6 @@ class PackedFlashBaseLayer1D(nn.Module): device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. use_flash_attn (bool): Whether use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. - moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. - moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. 
- moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. - moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. - moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to - infinite capacity). - moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. """ def __init__( @@ -84,15 +72,6 @@ class PackedFlashBaseLayer1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - num_experts: int = 1, - moe_gate_k: int = 1, - moe_capacity_factor: float = 1.0, - moe_eval_capacity_factor: float = 1.0, - moe_min_capacity: int = 4, - moe_noisy_gate_policy: str = None, - moe_drop_tokens: bool = True, - moe_use_rts: bool = True, - moe_use_residual: bool = False, ): super().__init__() self.checkpoint = checkpoint @@ -127,77 +106,41 @@ class PackedFlashBaseLayer1D(nn.Module): self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - for param in self.norm1.parameters(): - param.is_norm = True - for param in self.norm2.parameters(): - param.is_norm = True - - self.num_experts = num_experts - self.moe_gate_k = moe_gate_k - self.moe_capacity_factor = moe_capacity_factor - self.moe_eval_capacity_factor = moe_eval_capacity_factor - self.moe_min_capacity = moe_min_capacity - self.moe_noisy_gate_policy = moe_noisy_gate_policy - self.moe_drop_tokens = moe_drop_tokens - self.moe_use_rts = moe_use_rts - self.moe_use_residual = moe_use_residual - ep_size = gpc.get_world_size(ParallelMode.EXPERT) - if num_experts <= 1: # dense, not MoE - if use_swiglu: - self.mlp = FeedForward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - process_group=gpc.get_group(ParallelMode.TENSOR), - bias=False, - device=device, - dtype=dtype, - ) - else: - self.mlp = ParallelFusedMLP( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - activation="gelu_approx", - process_group=gpc.get_group(ParallelMode.TENSOR), - bias1=False, - bias2=False, - sequence_parallel=gpc.config.model.sequence_parallel, - checkpoint_lvl=0, - heuristic="auto", - device=device, - dtype=dtype, - ) - for _, param in self.mlp.named_parameters(): - if gpc.get_world_size(ParallelMode.TENSOR) > 1: - setattr(param, IS_TENSOR_PARALLEL, True) - else: - # replace mlp by MoE module. The expert in MoE is a FeedForward module. 
- self.mlp = MoE( - hidden_size=hidden_size, - num_experts=num_experts, - ep_size=ep_size, - k=moe_gate_k, - capacity_factor=moe_capacity_factor, - eval_capacity_factor=moe_eval_capacity_factor, - min_capacity=moe_min_capacity, - noisy_gate_policy=moe_noisy_gate_policy, - drop_tokens=moe_drop_tokens, - use_rts=moe_use_rts, - use_residual=moe_use_residual, + if use_swiglu: + self.mlp = FeedForward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, device=device, dtype=dtype, ) - for _, param in self.mlp.moe_layer.experts.named_parameters(): - if gpc.get_world_size(ParallelMode.TENSOR) > 1: - setattr(param, IS_TENSOR_PARALLEL, True) + else: + self.mlp = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(ParallelMode.TENSOR), + bias1=False, + bias2=False, + sequence_parallel=gpc.config.parallel.sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + for _, param in self.mlp.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm self.return_residual = False - self.reset_parameters() # TODO: check this should be changed when moe is added + self.reset_parameters() def reset_parameters(self): with torch.no_grad(): @@ -229,7 +172,7 @@ class PackedFlashBaseLayer1D(nn.Module): if self.checkpoint and self.training: return activation_checkpoint( self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) # TODO: check whether this will be affected by moe + ) else: return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) @@ -279,14 +222,9 @@ class PackedFlashBaseLayer1D(nn.Module): if self.residual_in_fp32: residual = residual.to(torch.float32) - # MLP. - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) - if self.num_experts <= 1: # dense mlp output - hidden_states = self.mlp(hidden_states) - else: # MoE output - hidden_states, moe_loss, _ = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states) - return hidden_states + residual, moe_loss + return hidden_states + residual class PackedFlashInternLm1D(nn.Module): @@ -316,17 +254,7 @@ class PackedFlashInternLm1D(nn.Module): residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. use_flash_attn (bool): Whether to use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. - moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. - moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. - moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. - moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. - moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent - to infinite capacity). 
- moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. + """ def __init__( @@ -357,15 +285,6 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - num_experts: bool = 1, - moe_gate_k: int = 1, - moe_capacity_factor: float = 1.0, - moe_eval_capacity_factor: float = 1.0, - moe_min_capacity: int = 4, - moe_noisy_gate_policy: str = None, - moe_drop_tokens: bool = True, - moe_use_rts: bool = True, - moe_use_residual: bool = False, ): super().__init__() @@ -415,15 +334,6 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, - num_experts=num_experts, - moe_gate_k=moe_gate_k, - moe_capacity_factor=moe_capacity_factor, - moe_eval_capacity_factor=moe_eval_capacity_factor, - moe_min_capacity=moe_min_capacity, - moe_noisy_gate_policy=moe_noisy_gate_policy, - moe_drop_tokens=moe_drop_tokens, - moe_use_rts=moe_use_rts, - moe_use_residual=moe_use_residual, ) for lid in range(num_layers) ] @@ -450,8 +360,7 @@ class PackedFlashInternLm1D(nn.Module): def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): # attention_mask: compute attention on the places where the value is 1 - # old condition may fail when use shared embedding - if gpc.is_pipeline_first_stage(): + if hasattr(self, "embedding"): hidden_states = self.embedding(input_ids) if self.embed_grad_scale != 1: hidden_states = ( @@ -472,16 +381,14 @@ class PackedFlashInternLm1D(nn.Module): indexes = indexes[0] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None - moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block( + hidden_states = block( hidden_states, cu_seqlens=cu_seqlens, indexes=indexes, inference_params=inference_params, max_seqlen=max_seqlen, ) - moe_losses.append(mos_loss) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) @@ -490,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module): if not self.parallel_output: hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) - return hidden_states, moe_losses + return hidden_states def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): @@ -558,15 +465,6 @@ def build_model_with_cfg( use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - num_experts: int = 1, - moe_gate_k: int = 1, - moe_capacity_factor: float = 1.0, - moe_eval_capacity_factor: float = 1.0, - moe_min_capacity: int = 4, - moe_noisy_gate_policy: str = None, - moe_drop_tokens: bool = True, - moe_use_rts: bool = True, - moe_use_residual: bool = False, ): """ Build model with config. @@ -597,17 +495,7 @@ def build_model_with_cfg( use_scaled_init (bool): Whether to use scaled init. True by default. use_swiglu (bool): Whether to use swiglu. True by default. use_flash_attn (bool): Whether to use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. - moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. 
- moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. - moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. - moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. - moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent - to infinite capacity). - moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. + """ cfg = dict( @@ -632,15 +520,6 @@ def build_model_with_cfg( use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, - num_experts=num_experts, - moe_gate_k=moe_gate_k, - moe_capacity_factor=moe_capacity_factor, - moe_eval_capacity_factor=moe_eval_capacity_factor, - moe_min_capacity=moe_min_capacity, - moe_noisy_gate_policy=moe_noisy_gate_policy, - moe_drop_tokens=moe_drop_tokens, - moe_use_rts=moe_use_rts, - moe_use_residual=moe_use_residual, ) return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py new file mode 100644 index 0000000..d4539a0 --- /dev/null +++ b/internlm/model/modeling_moe.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Optional + +import torch +from flash_attn.modules.embedding import ParallelGPT2Embeddings +from flash_attn.modules.mlp import ParallelFusedMLP +from torch import nn + +from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal +from internlm.model.embedding import Embedding1D +from internlm.model.linear import ( + FeedForward, + RewardModelLinear, + ScaleColumnParallelLinear, +) +from internlm.model.moe import MoE +from internlm.model.multi_head_attention import MHA +from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm +from internlm.solver.pipeline_utils import partition_uniform +from internlm.utils.checkpoint import activation_checkpoint +from internlm.utils.common import filter_kwargs +from internlm.utils.logger import get_logger +from internlm.utils.registry import MODEL_INITIALIZER + +MODEL_TYPE = "INTERNLM_MoE" + +logger = get_logger(__file__) +RMSNorm = try_import_RMSNorm() + + +class PackedFlashBaseLayer1D(nn.Module): + """ + 1D Packed Flash Base Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. 
+ norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + use_flash_attn (bool): Whether use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to + infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + norm_type: str = "rmsnorm", + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + num_experts: int = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.use_flash_attn = use_flash_attn + + head_dim = hidden_size // num_attention_heads + self.mixer = MHA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + process_group=gpc.get_group(ParallelMode.TENSOR), + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + use_flash_attn=use_flash_attn, + device=device, + dtype=dtype, + ) + + self.dropout1 = nn.Dropout(drop_rate) + if norm_type == "rmsnorm": + self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) + self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + + for param in self.norm1.parameters(): + param.is_norm = True + for param in self.norm2.parameters(): + param.is_norm = True + + self.num_experts = num_experts + self.moe_gate_k = moe_gate_k + self.moe_capacity_factor = moe_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.moe_min_capacity = moe_min_capacity + self.moe_noisy_gate_policy = moe_noisy_gate_policy + self.moe_drop_tokens = moe_drop_tokens + self.moe_use_rts = moe_use_rts + self.moe_use_residual = moe_use_residual + ep_size = gpc.get_world_size(ParallelMode.EXPERT) + if num_experts <= 1: # dense, not MoE + if use_swiglu: + self.mlp = FeedForward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + ) + else: + self.mlp = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(ParallelMode.TENSOR), + bias1=False, + bias2=False, + sequence_parallel=gpc.config.model.sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + for _, param in self.mlp.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + else: + # replace mlp by MoE module. The expert in MoE is a FeedForward module. 
+ self.mlp = MoE( + hidden_size=hidden_size, + num_experts=num_experts, + ep_size=ep_size, + k=moe_gate_k, + capacity_factor=moe_capacity_factor, + eval_capacity_factor=moe_eval_capacity_factor, + min_capacity=moe_min_capacity, + noisy_gate_policy=moe_noisy_gate_policy, + drop_tokens=moe_drop_tokens, + use_rts=moe_use_rts, + use_residual=moe_use_residual, + device=device, + dtype=dtype, + ) + for _, param in self.mlp.moe_layer.experts.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + + self.dropout2 = nn.Dropout(drop_rate) + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + self.reset_parameters() # TODO: check this should be changed when moe is added + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.mixer.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "Wqkv" in name: + normal_(std=0.006)(param.data) + elif self.use_scaled_init: + scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) + else: + normal_(std=0.0015)(param.data) + + for name, param in self.mlp.named_parameters(): + if param.ndim == 1 and "bias" in name: + param.data.zero_() + elif self.use_swiglu: + if self.use_scaled_init and "w2" in name: + scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) + else: + normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) + else: + normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) + + def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + if self.checkpoint and self.training: + return activation_checkpoint( + self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen + ) # TODO: check whether this will be affected by moe + else: + return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) + + def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). 
+ residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + + def _dropout_and_norm_attn(_hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = _dropped + _hidden_states = self.norm1(_residual.float()) + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.norm2(_residual.float()) + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + # MLP. + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + if self.num_experts <= 1: # dense mlp output + hidden_states = self.mlp(hidden_states) + else: # MoE output + hidden_states, moe_loss, _ = self.mlp(hidden_states) + + return hidden_states + residual, moe_loss + + +class PackedFlashInternLm1D(nn.Module): + """ + 1D Packed Flash InternLm. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + True by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + num_experts (int): The number of experts. 
<=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent + to infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: float = 0.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_split_hidden: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + num_experts: bool = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, + ): + super().__init__() + + checkpoint_layer_num = int(num_layers * checkpoint) + + if is_reward: + head_cls = RewardModelLinear + else: + head_cls = ScaleColumnParallelLinear + if first: + if embed_split_hidden: + self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + else: + self.embedding = ParallelGPT2Embeddings( + embed_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=-1, + process_group=gpc.get_group(ParallelMode.TENSOR), + padding_idx=None, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + for _, param in self.embedding.named_parameters(): + normal_(std=0.0052)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + self.embed_grad_scale = embed_grad_scale + self.blocks = nn.ModuleList( + [ + PackedFlashBaseLayer1D( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + 
use_flash_attn=use_flash_attn, + num_experts=num_experts, + moe_gate_k=moe_gate_k, + moe_capacity_factor=moe_capacity_factor, + moe_eval_capacity_factor=moe_eval_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_noisy_gate_policy=moe_noisy_gate_policy, + moe_drop_tokens=moe_drop_tokens, + moe_use_rts=moe_use_rts, + moe_use_residual=moe_use_residual, + ) + for lid in range(num_layers) + ] + ) + if last: + if norm_type == "rmsnorm": + self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.head = head_cls( + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + weight_scale=embed_grad_scale, + ) + for _, param in self.head.named_parameters(): + normal_(std=0.0052)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + self.parallel_output = parallel_output + + def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + # attention_mask: compute attention on the places where the value is 1 + # old condition may fail when use shared embedding + if gpc.is_pipeline_first_stage(): + hidden_states = self.embedding(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + if isinstance(cu_seqlens, list): + assert len(cu_seqlens) == 1 + cu_seqlens = cu_seqlens[0].to(hidden_states.device) + + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.squeeze(0) + hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state, + # the batch dimension with a size of 1 should be directly squeezed off. + + if indexes is not None: + assert len(indexes) == 1 + # The indexes are used to indicate the actual position IDs of each token in the packed input. + indexes = indexes[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None + + moe_losses = [] + for _, block in enumerate(self.blocks): + hidden_states, mos_loss = block( + hidden_states, + cu_seqlens=cu_seqlens, + indexes=indexes, + inference_params=inference_params, + max_seqlen=max_seqlen, + ) + moe_losses.append(mos_loss) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.float()) + if hasattr(self, "head"): + hidden_states = self.head(hidden_states) + + if not self.parallel_output: + hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + return hidden_states, moe_losses + + +def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. + num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. 
+ + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + if gpc.is_rank_for_log(): + logger.info(f"The layer sharding is {all_parts}.") + + models = [] + + for start, end in parts: + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. + kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 + kwargs["device"] = device + kwargs["start_layer_idx"] = start + chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device) + + models.append(chunk) + torch.distributed.barrier() + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + return model + + +@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) +def build_model_with_moe_cfg( + num_chunks=1, + checkpoint=0.0, + dtype=torch.float, + embed_split_hidden=False, + num_layers=48, + hidden_size=2048, + vocab_size=50304, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + max_position_embeddings=2048, + mlp_ratio=4.0, + residual_in_fp32=False, + use_dynamic_ntk_rope=False, + norm_type="rmsnorm", + drop_rate=0, + attn_drop_rate=0, + apply_post_layer_norm=False, # pylint: disable=W0613 + layer_norm_epsilon=1e-5, + is_reward=False, + dropout_selective_checkpoint=True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + num_experts: int = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, +): + """ + Build model with config. + + Args: + num_chunks (int): The number of partitions in pipeline parallel. 1 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. + dtype (torch.dtype): The type of data. torch.float by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + False by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + num_attention_heads (int): The number of attention head. 32 by default. + mlp_ratio (int): The ratio of MLP layers. 4.0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily + because this parameter requires inconsistent data types to be passed between pipelines, + which requires significant modifications to internlm. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + drop_rate (float): The dropout rate of input hidden state. 0 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + is_reward (bool): Whether to use reward model. False by default. 
+ dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. + use_scaled_init (bool): Whether to use scaled init. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent + to infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + cfg = dict( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + vocab_size=vocab_size, + embed_grad_scale=embed_grad_scale, + parallel_output=parallel_output, + mlp_ratio=mlp_ratio, + residual_in_fp32=residual_in_fp32, + max_position_embeddings=max_position_embeddings, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + norm_type=norm_type, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + layer_norm_epsilon=layer_norm_epsilon, + is_reward=is_reward, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + num_experts=num_experts, + moe_gate_k=moe_gate_k, + moe_capacity_factor=moe_capacity_factor, + moe_eval_capacity_factor=moe_eval_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_noisy_gate_policy=moe_noisy_gate_policy, + moe_drop_tokens=moe_drop_tokens, + moe_use_rts=moe_use_rts, + moe_use_residual=moe_use_residual, + ) + + return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 4181f20..1f77ef2 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -111,7 +111,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): adam_cfg = gpc.config.adam # split the moe parameters into different groups - if gpc.config.model.num_experts > 1: + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: params = create_param_groups(model, adam_cfg.weight_decay) else: params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}] @@ -435,8 +435,7 @@ def record_current_batch_training_metrics( infos = { "tflops": tflops, "step": batch_count, - "loss": loss.item() - moe_loss.item(), - "moe_loss": moe_loss.item(), + "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), "tgs (tokens/gpu/second)": tgs_origin, "tgs/last_tgs_1": last_tgs_1, "tgs/tgs_all": tgs_all, @@ -448,6 +447,8 @@ def record_current_batch_training_metrics( "loss_scale": scaler, "grad_norm": grad_norm, } + if moe_loss is not None: + infos["moe_loss"] 
= moe_loss.item() infos["micro_num"] = len(batch[1]) infos["num_consumed_tokens"] = train_state.num_consumed_tokens @@ -481,13 +482,14 @@ "step": batch_count, "lr": lr, "num_consumed_tokens": train_state.num_consumed_tokens, - "loss": loss.item() - moe_loss.item(), + "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), "flops": tflops, "tgs": last_tgs_1, "acc": acc_perplex["acc"], "perplexity": acc_perplex["perplexity"], "fwd_bwd_time": fwd_bwd_time, } + if moe_loss is not None: panel_metrics["moe_loss"] = moe_loss.item() for norm_key, norm_value in grad_norm.items(): panel_metrics[norm_key] = norm_value diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 6128249..13d4468 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -97,6 +97,7 @@ def evaluate_on_val_dls( disable=not verbose, leave=False, ): + moe_loss = None with torch.inference_mode(): if gpc.is_using_pp(): total_val_bsz = len(batch[1]) @@ -112,9 +113,15 @@ def evaluate_on_val_dls( tensor_shape=tensor_shape, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss, moe_loss = trainer.execute_schedule( - batch, forward_only=True, return_loss=True, return_output_label=False - ) + # Compatible for old code + if gpc.config.get("model_type") == "INTERNLM": + _, _, loss = trainer.execute_schedule( + batch, forward_only=True, return_loss=True, return_output_label=False + ) + elif gpc.config.get("model_type") == "INTERNLM_MoE": + _, _, loss, moe_loss = trainer.execute_schedule( + batch, forward_only=True, return_loss=True, return_output_label=False + ) else: total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 @@ -126,11 +133,16 @@ def evaluate_on_val_dls( grad_accum_batch_size=grad_accum_batch_size, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss, moe_loss = trainer.execute_schedule( - batch, forward_only=True, return_loss=True, return_output_label=False - ) + if gpc.config.get("model_type") == "INTERNLM": + _, _, loss = trainer.execute_schedule( - batch, forward_only=True, return_loss=True, return_output_label=False + ) + elif gpc.config.get("model_type") == "INTERNLM_MoE": + _, _, loss, moe_loss = trainer.execute_schedule( + batch, forward_only=True, return_loss=True, return_output_label=False + ) if verbose: - val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item() assert val_idx != -1 dist.barrier() diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 864fc24..caba6a9 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -186,11 +186,20 @@ def train( # do forward and backward timer("fwd-bwd").start() - _, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) + # Compatible for old code + moe_loss = None + if gpc.config.get("model_type") == "INTERNLM": + _, _, loss = trainer.execute_schedule( + batch, forward_only=False, return_loss=True, return_output_label=False + ) + elif gpc.config.get("model_type") == "INTERNLM_MoE": + _, _, loss, moe_loss = trainer.execute_schedule( + batch, forward_only=False, return_loss=True, return_output_label=False + ) if gpc.is_rank_for_log(): assert loss is not None and not math.isnan(loss.item()) global cur_loss_list - cur_loss_list.append(loss.item() - moe_loss.item()) + cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item()))
timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm) diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 80cb353..379a3e0 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -59,6 +59,7 @@ init_config = Config( def init_naive_model(): # let MODEL_INITIALIZER to work import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import + import internlm.model.modeling_moe # noqa # pylint: disable=unused-import from internlm.core.naive_amp import NaiveAMPModel from internlm.utils.registry import MODEL_INITIALIZER diff --git a/train.py b/train.py index 4d30b90..7a30f6c 100644 --- a/train.py +++ b/train.py @@ -219,12 +219,21 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - _, _, loss, moe_loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) + moe_loss = None + if gpc.config.get("model_type") == "INTERNLM": + _, _, loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + if gpc.config.get("model_type") == "INTERNLM_MoE": + _, _, loss, moe_loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm)