diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index a4efc03..d52e473 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -6,6 +6,9 @@ from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear from .metrics import AccPerplex from .modeling_internlm import build_model_with_cfg from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg +from .modeling_llama_moe import ( + build_model_with_moe_cfg as build_model_with_llama_moe_cfg, +) from .modeling_moe import build_model_with_moe_cfg from .moe import MoE from .multi_head_attention import MHA @@ -24,4 +27,5 @@ __all__ = [ "build_model_with_cfg", "build_model_with_moe_cfg", "build_model_with_llama_cfg", + "build_model_with_llama_moe_cfg", ] diff --git a/internlm/model/modeling_llama_moe.py b/internlm/model/modeling_llama_moe.py new file mode 100644 index 0000000..2740e98 --- /dev/null +++ b/internlm/model/modeling_llama_moe.py @@ -0,0 +1,1255 @@ +# Copyright (c) InternLM. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.naive_amp import set_fp32_attr_to_module +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.embedding import Embedding1D, RotaryEmbedding +from internlm.model.linear import ( + ColumnParallelLinearTorch, + FeedForward, + RewardModelLinear, + RowParallelLinearTorch, + ScaleColumnParallelLinear, +) +from internlm.model.moe import MoE +from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm +from internlm.solver.pipeline_utils import partition_uniform +from internlm.utils.checkpoint import activation_checkpoint +from internlm.utils.common import filter_kwargs +from internlm.utils.logger import get_logger +from internlm.utils.registry import MODEL_INITIALIZER + +try: + from flash_attn import flash_attn_varlen_kvpacked_func + from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc + from flash_attn.modules.embedding import ParallelGPT2Embeddings + from flash_attn.modules.mha import ( + CrossAttention, + FlashCrossAttention, + FlashSelfAttention, + SelfAttention, + _update_kv_cache, + ) + from flash_attn.modules.mlp import ParallelFusedMLP + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + pass + +MODEL_TYPE = "LLAMA_MoE" + +logger = get_logger(__file__) +RMSNorm = try_import_RMSNorm() + + +class MHA(nn.Module): + """ + Multi-head self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + num_kv_heads (int): The number of kv attention heads. + process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. + bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and + output projection. True by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. + True by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + rot_embed_HF_impl: rotary embedding hf implementation. False by default. + + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + process_group: Optional[torch.distributed.ProcessGroup], + bias: bool = True, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + rope_base: int = 10000, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + use_flash_attn: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + rot_embed_HF_impl: Optional[bool] = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + self.head_dim = self.embed_dim // num_heads + self.num_kv_heads = num_kv_heads + self.kv_dim = self.head_dim * num_kv_heads + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.dtype = dtype + + self.rot_embed_HF_impl = rot_embed_HF_impl + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + + if self.rotary_emb_dim > 0: + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device + ) + + # notice here should change bias=True + self.wq = ColumnParallelLinearTorch( + embed_dim, + embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.wk = ColumnParallelLinearTorch( + embed_dim, + self.kv_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.wv = ColumnParallelLinearTorch( + embed_dim, + self.kv_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + + self.inner_cross_attn_causal = causal + self.inner_cross_attn_softmax_scale = softmax_scale + self.inner_cross_attn_dropout = dropout + + # output projection always have the bias (for now) + self.wo = RowParallelLinearTorch( + embed_dim, + embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + # need to assign tp attribute so that internlm know it is tensor parallel module + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + for name in ["wo", "wq", "wk", "wv"]: + for param in getattr(self, name).parameters(): + setattr(param, IS_TENSOR_PARALLEL, True) + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + if kwargs.get("indexes", None) is not None: + return self._packed_forward(x=x, inference_params=inference_params, **kwargs) + else: + return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) + + def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + bsz, _, _ = x.shape + q, k, v = self.wq(x), self.wk(x), self.wv(x) + if seqlen is None: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + else: + q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) + + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + if inference_params is None: + if self.rotary_emb_dim > 0: + q = self.rotary_emb._single_eval_forward(q) + k = self.rotary_emb._single_eval_forward(k) + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + if self.dtype is torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv) + + else: + assert self.rotary_emb_dim > 0 + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + empties = inference_params.attention_mask[..., -1].sum(dim=-1) + moved_q = q.clone() + moved_k = k.clone() + if inference_params.sequence_len_offset == 0: + for i in range(len(empties)): + if empties[i] != 0: + moved_q[i][: -empties[i]] = q[i][empties[i] :] + moved_k[i][: -empties[i]] = k[i][empties[i] :] + moved_q = self.rotary_emb._single_eval_forward( + moved_q, seqlen_offset=inference_params.sequence_len_offset + ) + moved_k = self.rotary_emb._single_eval_forward( + moved_k, seqlen_offset=inference_params.sequence_len_offset + ) + for i in range(len(empties)): + if empties[i] != 0: + q[i][empties[i] :] = moved_q[i][: -empties[i]] + k[i][empties[i] :] = moved_k[i][: -empties[i]] + else: + q[i] = moved_q[i] + k[i] = moved_k[i] + else: + q = q.squeeze(1) + k = k.squeeze(1) + q = self.rotary_emb._single_forward( + q, + inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) + - empties, + ).unsqueeze(1) + k = self.rotary_emb._single_forward( + k, + inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) + - empties, + ).unsqueeze(1) + else: + raise NotImplementedError( + "You should make sure you are aware that you are changing the method of generating." + "According to your generation function instead of inference/seq_generator_module.py, " + "You may implement here for normal running." + ) + + kv = torch.stack([k, v], dim=2) + + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + assert kv.size(1) == 1, "update kv lenth more than 1" + inference_params.key_value_memory_dict[self.layer_idx][ + :, inference_params.keep_first : inference_params.window_size - 1, ... + ] = inference_params.key_value_memory_dict[self.layer_idx][ + :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... + ].clone() + inference_params.real_sequence_len_offset = inference_params.sequence_len_offset + inference_params.sequence_len_offset = inference_params.window_size - 1 + + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + inference_params.sequence_len_offset = inference_params.real_sequence_len_offset + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + # When using FP16, there is a high probability of NAN in the KV. + # Since NAN cannot be removed by multiplying with and 0, it needs + # to be removed manually here. + kv = torch.where(torch.isnan(kv), 0, kv) + + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + assert self.use_flash_attn is True + if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = inference_params.attention_mask[:, None, ...] + attn_mask = torch.logical_or( + torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask + ) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), + attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ) + cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) + max_seqlen_q = attn_mask4flsh.shape[-1] + max_seqlen_k = attn_mask4flsh.shape[-1] + total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + if self.dtype is torch.float32: + if total_q.dtype not in [torch.float16, torch.bfloat16]: + total_q = total_q.to(torch.bfloat16) + if total_kv.dtype not in [torch.float16, torch.bfloat16]: + total_kv = total_kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + output = FlashAttnVarlenKVPackedFunc.apply( + total_q, + total_kv, + cu_seqlens, + cu_seqlens, + max_seqlen_q, + max_seqlen_k, + 0.0, + None, + True, + False, + ).to(self.dtype) + else: + output = FlashAttnVarlenKVPackedFunc.apply( + total_q, + total_kv, + cu_seqlens, + cu_seqlens, + max_seqlen_q, + max_seqlen_k, + 0.0, + None, + True, + False, + ) + + context = torch.zeros_like(q) + context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) + + else: + attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + attn_mask = torch.concat( + [ + attn_mask[..., : inference_params.keep_first], + attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], + ], + dim=-1, + ) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + expansion = q.size(2) // k.size(2) + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + else: + if self.dtype is torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv, causal=True).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv, causal=True) + if seqlen is None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + out = self.wo(context) + return out + + def _packed_forward(self, x, inference_params=None, **kwargs): + """ + we delete seqlen=None for lint check, cause this arg is not used. + + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + assert self.use_flash_attn is True + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "t (h d) -> t h d", d=self.head_dim) + k = rearrange(k, "t (h d) -> t h d", d=self.head_dim) + v = rearrange(v, "t (h d) -> t h d", d=self.head_dim) + + # qkv shift + # the rotary embedding in flash attention module in performed by separating the front and back parts, while + # most of others are done by odd-even methods. + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + + indexes = kwargs.pop("indexes") + q = self.rotary_emb._single_forward(q, indexes=indexes) + k = self.rotary_emb._single_forward(k, indexes=indexes) + + if inference_params is None: + kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1) + if self.dtype is torch.float32: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ).to(self.dtype) + else: + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ) + else: + raise RuntimeError("Not support this right now") + + context = rearrange(context, "b h d -> b (h d)") # recover shape + out = self.wo(context) + return out + + +class PackedFlashLlamaLayer1D(nn.Module): + """ + 1D Packed Flash Llama Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of kv attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether use post layer norm. False by default. + fused_dropout_add_ln (bool): Whether use fused dropout add ln. True by default. + no_bias (bool): Whether remove bias. False by default. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + adapt_hf (bool): Whether adapt hf. False by default. + dropout_selective_checkpoint (bool): Whether use dropout selective checkpoint. True by default. + use_scaled_init (bool): Whether use scaled init. True by default. + use_swiglu (bool): Whether use swiglu. True by default. + use_flash_attn (bool): Whether use flash-attn. True by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent + to infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + num_experts: int = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. + self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.use_flash_attn = use_flash_attn + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + head_dim = hidden_size // num_attention_heads + self.attention = MHA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + process_group=gpc.get_group(ParallelMode.TENSOR), + dropout=attn_drop_rate, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + use_flash_attn=use_flash_attn, + device=device, + dtype=dtype, + rot_embed_HF_impl=adapt_hf, + bias=not no_bias, + rope_base=rope_base, + ) + + self.dropout1 = nn.Dropout(drop_rate) + if norm_type == "rmsnorm": + self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + set_fp32_attr_to_module(self.attention_norm) + set_fp32_attr_to_module(self.ffn_norm) + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + self.num_experts = num_experts + self.moe_gate_k = moe_gate_k + self.moe_capacity_factor = moe_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.moe_min_capacity = moe_min_capacity + self.moe_noisy_gate_policy = moe_noisy_gate_policy + self.moe_drop_tokens = moe_drop_tokens + self.moe_use_rts = moe_use_rts + self.moe_use_residual = moe_use_residual + ep_size = gpc.get_world_size(ParallelMode.EXPERT) + if num_experts <= 1: # dense, not MoE + if use_swiglu: + self.feed_forward = FeedForward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + ) + else: + self.feed_forward = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(ParallelMode.TENSOR), + bias1=False, + bias2=False, + sequence_parallel=sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + for _, param in self.feed_forward.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + else: + # replace mlp by MoE module. The expert in MoE is a FeedForward module. + self.feed_forward = MoE( + hidden_size=hidden_size, + num_experts=num_experts, + ep_size=ep_size, + k=moe_gate_k, + capacity_factor=moe_capacity_factor, + eval_capacity_factor=moe_eval_capacity_factor, + min_capacity=moe_min_capacity, + noisy_gate_policy=moe_noisy_gate_policy, + drop_tokens=moe_drop_tokens, + use_rts=moe_use_rts, + use_residual=moe_use_residual, + device=device, + dtype=dtype, + ) + for _, param in self.feed_forward.moe_layer.experts.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + set_fp32_attr_to_module(self.feed_forward.moe_layer.gate) + + for param in self.attention_norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + for param in self.ffn_norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + + self.dropout2 = nn.Dropout(drop_rate) + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward( + self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + if self.checkpoint and self.training: + return activation_checkpoint( + self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen + ) + else: + return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + + def _forward( + self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(torch.float32)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + # MLP. + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + if self.num_experts <= 1: # dense mlp output + hidden_states = self.feed_forward(hidden_states) + else: # MoE output + hidden_states, moe_loss, _ = self.feed_forward(hidden_states) + + return hidden_states + residual, moe_loss + else: + assert residual is None + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + mixer_out = self.attention(hidden_states, **mixer_kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + # MLP. + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + if self.num_experts <= 1: # dense mlp output + mlp_out = self.feed_forward(hidden_states) + else: # MoE output + mlp_out, moe_loss, _ = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + # NOTE: should not be here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states, moe_loss + + +class PackedFlashLlama1D(nn.Module): + """ + 1D Packed Flash Llama. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + num_kv_attention_heads (int): The number of kv attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + True by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + apply_post_layer_norm (bool): Whether use post layer norm. False by default. + no_bias (bool): Whether remove bias. False by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + adapt_hf (bool): Whether adapt hf. False by default. + is_reward (bool): Whether use is_reward. False by default. + dropout_selective_checkpoint (bool): Whether dropout selective checkpoint. True by default. + use_scaled_init (bool): Whether use scaled init. True by default. + use_swiglu (bool): Whether use swiglu. True by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent + to infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + checkpoint_fraction: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_split_hidden: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + num_experts: bool = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, + ): + super().__init__() + + self.use_flash_attn = use_flash_attn + if checkpoint_fraction <= 0: + checkpoint = False + if not checkpoint: + checkpoint_fraction = 0 + checkpoint_layer_num = num_layers * checkpoint_fraction + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + if is_reward: + head_cls = RewardModelLinear + else: + head_cls = ScaleColumnParallelLinear + if first: + if embed_split_hidden: + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + else: + + self.tok_embeddings = ParallelGPT2Embeddings( + embed_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=-1, + process_group=gpc.get_group(ParallelMode.TENSOR), + padding_idx=None, + sequence_parallel=sequence_parallel, + device=device, + dtype=dtype, + ) + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + self.embed_grad_scale = embed_grad_scale + self.layers = nn.ModuleList( + [ + PackedFlashLlamaLayer1D( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + adapt_hf=adapt_hf, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + num_experts=num_experts, + moe_gate_k=moe_gate_k, + moe_capacity_factor=moe_capacity_factor, + moe_eval_capacity_factor=moe_eval_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_noisy_gate_policy=moe_noisy_gate_policy, + moe_drop_tokens=moe_drop_tokens, + moe_use_rts=moe_use_rts, + moe_use_residual=moe_use_residual, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + if norm_type == "rmsnorm": + self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + for param in self.norm.parameters(): + if gpc.config.parallel.sequence_parallel is True: + setattr(param, IS_SEQUENCE_PARALLEL, True) + + self.output = head_cls( + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + weight_scale=embed_grad_scale, + ) + + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) + + self.parallel_output = parallel_output + + def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "tok_embeddings"): + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + if isinstance(cu_seqlens, list): + assert len(cu_seqlens) == 1 + cu_seqlens = cu_seqlens[0].to(hidden_states.device) + + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.squeeze(0) + hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state, + # the batch dimension with a size of 1 should be directly squeezed off. + + if indexes is not None: + assert len(indexes) == 1 + # The indexes are used to indicate the actual position IDs of each token in the packed input. + indexes = indexes[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None + + moe_losses = [] + for _, block in enumerate(self.layers): + hidden_states, mos_loss = block( + hidden_states, + residual=None, + cu_seqlens=cu_seqlens, + indexes=indexes, + inference_params=inference_params, + max_seqlen=max_seqlen, + ) + moe_losses.append(mos_loss) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.float()) + + extra_hidden_states_list = None + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + if not self.parallel_output: + hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + if extra_hidden_states_list is not None: + extra_hidden_states_list = [ + gather_forward_split_backward(extra_hidden_states, ParallelMode.TENSOR, dim=-1) + for extra_hidden_states in extra_hidden_states_list # pylint: disable=E1133 + ] + + if extra_hidden_states_list is not None: + return (hidden_states, extra_hidden_states_list) + + return hidden_states, moe_losses + + +def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. + num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. + + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + if gpc.is_rank_for_log(): + logger.info(f"The layer sharding is {all_parts}.") + + models = [] + kwargs["checkpoint_fraction"] = 1.0 + start_idx, end_idx = 0, 0 + for start, end in parts: + start_idx, end_idx = start, end + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. + kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 + kwargs["device"] = device + kwargs["start_layer_idx"] = start + chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) + + models.append(chunk) + torch.distributed.barrier() + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + setattr(model, "first_layer", start_idx) + setattr(model, "last_layer", end_idx) + return model + + +@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) +def build_model_with_moe_cfg( + num_chunks=1, + checkpoint=False, + dtype=torch.float, + embed_split_hidden=False, + num_layers=48, + hidden_size=2048, + vocab_size=50304, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=None, + mlp_ratio=4.0, + residual_in_fp32=False, + norm_type="rmsnorm", + adapt_hf=False, + drop_rate=0, + attn_drop_rate=0, + apply_post_layer_norm=False, # pylint: disable=W0613 + no_bias=False, + deepnorm=False, # pylint: disable=W0613 + layer_norm_epsilon=1e-5, + is_reward=False, + dropout_selective_checkpoint=True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + num_experts: int = 1, + moe_gate_k: int = 1, + moe_capacity_factor: float = 1.0, + moe_eval_capacity_factor: float = 1.0, + moe_min_capacity: int = 4, + moe_noisy_gate_policy: str = None, + moe_drop_tokens: bool = True, + moe_use_rts: bool = True, + moe_use_residual: bool = False, +): + """ + Build model with config. + + Args: + num_chunks (int): The number of partitions in pipeline parallel. 1 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. + dtype (torch.dtype): The type of data. torch.float by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + False by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + num_attention_heads (int): The number of attention head. 32 by default. + mlp_ratio (int): The ratio of MLP layers. 4.0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily + because this parameter requires inconsistent data types to be passed between pipelines, + which requires significant modifications to internlm. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + drop_rate (float): The dropout rate of input hidden state. 0 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + is_reward (bool): Whether to use reward model. False by default. + dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. + use_scaled_init (bool): Whether to use scaled init. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. + moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. + moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent + to infinite capacity). + moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. + moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE + (https://arxiv.org/abs/2201.05596) layer. + """ + + cfg = dict( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + vocab_size=vocab_size, + embed_grad_scale=embed_grad_scale, + parallel_output=parallel_output, + mlp_ratio=mlp_ratio, + apply_post_layer_norm=apply_post_layer_norm, + no_bias=no_bias, + residual_in_fp32=residual_in_fp32, + norm_type=norm_type, + adapt_hf=adapt_hf, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + layer_norm_epsilon=layer_norm_epsilon, + is_reward=is_reward, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + embedding_init_std=embedding_init_std, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + out_head_init_std=out_head_init_std, + init_type=init_type, + rope_base=rope_base, + num_experts=num_experts, + moe_gate_k=moe_gate_k, + moe_capacity_factor=moe_capacity_factor, + moe_eval_capacity_factor=moe_eval_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_noisy_gate_policy=moe_noisy_gate_policy, + moe_drop_tokens=moe_drop_tokens, + moe_use_rts=moe_use_rts, + moe_use_residual=moe_use_residual, + ) + + return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/tools/transformers/mixtral2llamamoe.py b/tools/transformers/mixtral2llamamoe.py new file mode 100644 index 0000000..5d742d0 --- /dev/null +++ b/tools/transformers/mixtral2llamamoe.py @@ -0,0 +1,182 @@ +import argparse +import os + +import torch +from tqdm import tqdm +from transformers import AutoConfig + + +def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash): + hf_state = {} + print("Loading HF checkpoints...") + for filename in tqdm(os.listdir(src)): + if not filename.endswith(".bin"): + continue + hf_state.update(torch.load(os.path.join(src, filename))) + + print("Reverting HF checkpoints to InternLM...") + config = AutoConfig.from_pretrained(src, trust_remote_code=True) + + n_heads = config.num_attention_heads + try: + n_kv_heads = config.num_key_value_heads + except AttributeError: + n_kv_heads = n_heads + dim = config.hidden_size + + # n_heads_per_shard = n_heads // tp_size + # dims_per_head = dim // n_heads + + def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): + if adapt_hf: + return w + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # revert + states = [{} for _ in range(tp_size)] + moe_states = [ + [[{} for _ in range(tp_size)] for _ in range(config.num_experts)] for _ in range(config.num_hidden_layers) + ] + + # layers + for layer_i in tqdm(range(config.num_hidden_layers)): + # no-moe + for i in range(tp_size): + states[i][f"model.layers.{layer_i}.attention_norm.weight"] = hf_state[ + f"model.layers.{layer_i}.input_layernorm.weight" + ].clone() + states[i][f"model.layers.{layer_i}.ffn_norm.weight"] = hf_state[ + f"model.layers.{layer_i}.post_attention_layernorm.weight" + ].clone() + states[i][f"model.layers.{layer_i}.feed_forward.moe_layer.gate.wg.weight"] = hf_state[ + f"model.layers.{layer_i}.mlp.gate.weight" + ].clone() + + # mha + wqs = ( + permute(hf_state[f"model.layers.{layer_i}.self_attn.q_proj.weight"]) + # .view(-1, dims_per_head, dim) + .chunk(tp_size, 0) + ) + wks = ( + permute(hf_state[f"model.layers.{layer_i}.self_attn.k_proj.weight"], n_kv_heads, -1, dim) + # .view(-1, dims_per_head, dim) + .chunk(tp_size, 0) + ) + wvs = ( + hf_state[f"model.layers.{layer_i}.self_attn.v_proj.weight"] + # .view(-1, dims_per_head, dim) + .chunk(tp_size, 0) + ) + wos = hf_state[f"model.layers.{layer_i}.self_attn.o_proj.weight"].chunk(tp_size, 1) + for i in range(tp_size): + states[i][f"model.layers.{layer_i}.attention.wq.weight"] = wqs[i].reshape(-1, dim).clone() + states[i][f"model.layers.{layer_i}.attention.wk.weight"] = wks[i].reshape(-1, dim).clone() + states[i][f"model.layers.{layer_i}.attention.wv.weight"] = wvs[i].reshape(-1, dim).clone() + states[i][f"model.layers.{layer_i}.attention.wo.weight"] = wos[i].clone() + + # moe + for expert_id in range(config.num_experts): + w1s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w1.weight"].chunk(tp_size, 0) + w2s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w3.weight"].chunk(tp_size, 0) + w3s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w2.weight"].chunk(tp_size, 1) + for i in range(tp_size): + moe_states[layer_i][expert_id][i][ + f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w1.weight" + ] = w1s[i].clone() + moe_states[layer_i][expert_id][i][ + f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w2.weight" + ] = w2s[i].clone() + moe_states[layer_i][expert_id][i][ + f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight" + ] = w3s[i].clone() + + if embed_split_hidden: + embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1) + states[i]["model.tok_embeddings.weight"] = embeds[i].clone() + else: + embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0) + states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone() + + outputs = hf_state["lm_head.weight"].chunk(tp_size, 0) + for i in range(tp_size): + states[i]["model.norm.weight"] = hf_state["model.norm.weight"].clone() + states[i]["model.output.weight"] = outputs[i].clone() + + mlp_ratio = round((config.intermediate_size - 255) / config.hidden_size + 0.01, 2) + if "rotary" in config.to_dict(): + rope_base = config.rotary["base"] + elif "rope_theta" in config.to_dict(): + rope_base = config.rope_theta + else: + rope_base = 10000 + model_config = dict( + num_attention_heads=n_heads, + embed_split_hidden=embed_split_hidden, + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + num_layers=config.num_hidden_layers, + norm_type="rmsnorm", + layer_norm_epsilon=config.rms_norm_eps, + no_bias=True, + mlp_ratio=mlp_ratio, + num_kv_attention_heads=n_kv_heads, + dtype=config.torch_dtype, + # norm_head=False, + adapt_hf=adapt_hf, + use_flash_attn=use_flash, + rope_base=rope_base, + num_experts=config.num_experts, + moe_gate_k=config.num_experts_per_token, + ) + print("Model Config:", model_config) + + # split + os.makedirs(tgt, exist_ok=True) + print(f"Saving to {tgt}...") + for tp in tqdm(range(tp_size)): + torch.save(states[tp], os.path.join(tgt, f"model_tp{tp}_pp0.pt")) + for moe_layer_id in range(config.num_hidden_layers): + for expert_id in range(config.num_experts): + for tp in tqdm(range(tp_size)): + torch.save( + moe_states[moe_layer_id][expert_id][tp], + os.path.join(tgt, f"model_moe_layer{moe_layer_id}_expert{expert_id}_tp{tp}.pt"), + ) + torch.save(model_config, os.path.join(tgt, "model_config.pt")) + + +def print_args(args): + print("-------------- Arguments --------------") + print(f"Source Path: {args.src}") + print(f"Target Path: {args.tgt}") + print(f"TP Size: {args.tp_size}") + print(f"Embeb Split Hidden: {args.embed_split}") + print(f"Adapt HF: {args.adapt_hf}") + print(f"Use Flash Attn: {args.use_flash}") + print("---------------------------------------") + + +def parse_args(): + parser = argparse.ArgumentParser() + # model + parser.add_argument("--src", type=str, help="Input folder") + parser.add_argument("--tgt", type=str, help="Output folder") + parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel") + parser.add_argument("--embed_split", action="store_true", help="embed_split_hidden of InternLM") + parser.add_argument("--adapt_hf", action="store_true", help="adapt_hf of InternLM") + parser.add_argument("--use_flash", action="store_true", help="use_flash_attn of InternLM") + parser.add_argument("--version", type=int, help="Determine the relavance between w2, w3 and up_gate, down_fate.") + + args = parser.parse_args() + + return args + + +# download ckpt from https://huggingface.co/DiscoResearch/mixtral-7b-8expert and +# srun -p llm_s python tools/transformers/mixtral2llamamoe.py --src ./ckpt/mixtral-7b-8expert/ --tgt ckpt --tp_size {tp} +if __name__ == "__main__": + args = parse_args() + print_args(args) + + revert(args.src, args.tgt, args.tp_size, args.embed_split, args.adapt_hf, args.use_flash)