# Copyright (c) InternLM. All rights reserved. import math from typing import Optional import torch import torch.nn.functional as F from einops import rearrange from torch import nn from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_fp32_attr_to_module from internlm.initialize.initialize_tensor import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) from internlm.model.embedding import Embedding1D, RotaryEmbedding from internlm.model.linear import ( ColumnParallelLinearTorch, FeedForward, RewardModelLinear, RowParallelLinearTorch, ScaleColumnParallelLinear, ) from internlm.model.moe import MoE from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm from internlm.solver.pipeline_utils import partition_uniform from internlm.utils.checkpoint import activation_checkpoint from internlm.utils.common import filter_kwargs from internlm.utils.logger import get_logger from internlm.utils.registry import MODEL_INITIALIZER try: from flash_attn import flash_attn_varlen_kvpacked_func from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc from flash_attn.modules.embedding import ParallelGPT2Embeddings from flash_attn.modules.mha import ( CrossAttention, FlashCrossAttention, FlashSelfAttention, SelfAttention, _update_kv_cache, ) from flash_attn.modules.mlp import ParallelFusedMLP from flash_attn.ops.layer_norm import dropout_add_layer_norm except ImportError: pass MODEL_TYPE = "LLAMA_MoE" logger = get_logger(__file__) RMSNorm = try_import_RMSNorm() class MHA(nn.Module): """ Multi-head self-attention and cross-attention. Args: embed_dim (int): The dimention of hidden state. num_heads (int): The number of attention heads. num_kv_heads (int): The number of kv attention heads. process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and output projection. True by default. dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. softmax_scale (float): The temperature to use for the softmax attention. causal (boolean): Whether to apply causal attention mask. False by default. layer_idx (int): The index of current layer. None by default. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. True by default. device (Optional[Union[str, torch.device]]): The device will be used. dtype (Optional[torch.dtype]): The type of data. rot_embed_HF_impl: rotary embedding hf implementation. False by default. """ def __init__( self, embed_dim: int, num_heads: int, num_kv_heads: int, process_group: Optional[torch.distributed.ProcessGroup], bias: bool = True, dropout: float = 0.0, softmax_scale: float = None, causal: bool = False, layer_idx: int = None, rope_base: int = 10000, rotary_emb_dim: int = 0, rotary_emb_scale_base: int = 0, use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, rot_embed_HF_impl: Optional[bool] = False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads self.num_kv_heads = num_kv_heads self.kv_dim = self.head_dim * num_kv_heads self.causal = causal self.layer_idx = layer_idx self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn self.dtype = dtype self.rot_embed_HF_impl = rot_embed_HF_impl sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) if self.rotary_emb_dim > 0: self.rotary_emb = RotaryEmbedding( self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device ) # notice here should change bias=True self.wq = ColumnParallelLinearTorch( embed_dim, embed_dim, process_group, bias=bias, sequence_parallel=sequence_parallel, **factory_kwargs, ) self.wk = ColumnParallelLinearTorch( embed_dim, self.kv_dim, process_group, bias=bias, sequence_parallel=sequence_parallel, **factory_kwargs, ) self.wv = ColumnParallelLinearTorch( embed_dim, self.kv_dim, process_group, bias=bias, sequence_parallel=sequence_parallel, **factory_kwargs, ) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) self.inner_cross_attn_causal = causal self.inner_cross_attn_softmax_scale = softmax_scale self.inner_cross_attn_dropout = dropout # output projection always have the bias (for now) self.wo = RowParallelLinearTorch( embed_dim, embed_dim, process_group, bias=bias, sequence_parallel=sequence_parallel, **factory_kwargs, ) # need to assign tp attribute so that internlm know it is tensor parallel module if gpc.get_world_size(ParallelMode.TENSOR) > 1: for name in ["wo", "wq", "wk", "wv"]: for param in getattr(self, name).parameters(): setattr(param, IS_TENSOR_PARALLEL, True) def forward(self, x, seqlen=None, inference_params=None, **kwargs): if kwargs.get("indexes", None) is not None: return self._packed_forward(x=x, inference_params=inference_params, **kwargs) else: return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ bsz, _, _ = x.shape q, k, v = self.wq(x), self.wk(x), self.wv(x) if seqlen is None: q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) else: q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) if not self.rot_embed_HF_impl: q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) if inference_params is None: if self.rotary_emb_dim > 0: q = self.rotary_emb._single_eval_forward(q) k = self.rotary_emb._single_eval_forward(k) kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) if self.dtype is torch.float32 and self.use_flash_attn: if q.dtype not in [torch.float16, torch.bfloat16]: q = q.to(torch.bfloat16) if kv.dtype not in [torch.float16, torch.bfloat16]: kv = kv.to(torch.bfloat16) with torch.cuda.amp.autocast(dtype=torch.bfloat16): context = self.inner_cross_attn(q, kv).to(self.dtype) else: context = self.inner_cross_attn(q, kv) else: assert self.rotary_emb_dim > 0 if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: empties = inference_params.attention_mask[..., -1].sum(dim=-1) moved_q = q.clone() moved_k = k.clone() if inference_params.sequence_len_offset == 0: for i in range(len(empties)): if empties[i] != 0: moved_q[i][: -empties[i]] = q[i][empties[i] :] moved_k[i][: -empties[i]] = k[i][empties[i] :] moved_q = self.rotary_emb._single_eval_forward( moved_q, seqlen_offset=inference_params.sequence_len_offset ) moved_k = self.rotary_emb._single_eval_forward( moved_k, seqlen_offset=inference_params.sequence_len_offset ) for i in range(len(empties)): if empties[i] != 0: q[i][empties[i] :] = moved_q[i][: -empties[i]] k[i][empties[i] :] = moved_k[i][: -empties[i]] else: q[i] = moved_q[i] k[i] = moved_k[i] else: q = q.squeeze(1) k = k.squeeze(1) q = self.rotary_emb._single_forward( q, inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties, ).unsqueeze(1) k = self.rotary_emb._single_forward( k, inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties, ).unsqueeze(1) else: raise NotImplementedError( "You should make sure you are aware that you are changing the method of generating." "According to your generation function instead of inference/seq_generator_module.py, " "You may implement here for normal running." ) kv = torch.stack([k, v], dim=2) assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" if hasattr(inference_params, "window_size") and inference_params.window_size is not None: if inference_params.window_size <= inference_params.sequence_len_offset: assert kv.size(1) == 1, "update kv lenth more than 1" inference_params.key_value_memory_dict[self.layer_idx][ :, inference_params.keep_first : inference_params.window_size - 1, ... ] = inference_params.key_value_memory_dict[self.layer_idx][ :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... ].clone() inference_params.real_sequence_len_offset = inference_params.sequence_len_offset inference_params.sequence_len_offset = inference_params.window_size - 1 kv = _update_kv_cache(kv, inference_params, self.layer_idx) inference_params.sequence_len_offset = inference_params.real_sequence_len_offset else: kv = _update_kv_cache(kv, inference_params, self.layer_idx) else: kv = _update_kv_cache(kv, inference_params, self.layer_idx) # When using FP16, there is a high probability of NAN in the KV. # Since NAN cannot be removed by multiplying with and 0, it needs # to be removed manually here. kv = torch.where(torch.isnan(kv), 0, kv) if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: assert self.use_flash_attn is True if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) attn_mask = inference_params.attention_mask[:, None, ...] attn_mask = torch.logical_or( torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask ) attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) cu_seqlens = torch.concat( [ torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), ], dim=0, ) cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) max_seqlen_q = attn_mask4flsh.shape[-1] max_seqlen_k = attn_mask4flsh.shape[-1] total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] ) if self.dtype is torch.float32: if total_q.dtype not in [torch.float16, torch.bfloat16]: total_q = total_q.to(torch.bfloat16) if total_kv.dtype not in [torch.float16, torch.bfloat16]: total_kv = total_kv.to(torch.bfloat16) with torch.cuda.amp.autocast(dtype=torch.bfloat16): output = FlashAttnVarlenKVPackedFunc.apply( total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False, ).to(self.dtype) else: output = FlashAttnVarlenKVPackedFunc.apply( total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False, ) context = torch.zeros_like(q) context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) else: attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) if hasattr(inference_params, "window_size") and inference_params.window_size is not None: if inference_params.window_size <= inference_params.sequence_len_offset: attn_mask = torch.concat( [ attn_mask[..., : inference_params.keep_first], attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], ], dim=-1, ) k, v = torch.chunk(kv, 2, dim=2) k = k.squeeze(2) v = v.squeeze(2) sp = k.shape expansion = q.size(2) // k.size(2) scores = torch.einsum( "blhd,bnhd->bhln", q, k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), ) / math.sqrt(q.size(-1)) scores = scores.masked_fill(attn_mask, -65000.0) scores = F.softmax(scores, dim=-1) # bsz x h x L x L context = torch.einsum( "bhmn,bnhd->bmhd", scores, v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), ) else: if self.dtype is torch.float32 and self.use_flash_attn: if q.dtype not in [torch.float16, torch.bfloat16]: q = q.to(torch.bfloat16) if kv.dtype not in [torch.float16, torch.bfloat16]: kv = kv.to(torch.bfloat16) with torch.cuda.amp.autocast(dtype=torch.bfloat16): context = self.inner_cross_attn(q, kv, causal=True).to(self.dtype) else: context = self.inner_cross_attn(q, kv, causal=True) if seqlen is None: context = rearrange(context, "b s h d -> b s (h d)") else: context = rearrange(context, "b s h d -> (b s) (h d)") out = self.wo(context) return out def _packed_forward(self, x, inference_params=None, **kwargs): """ we delete seqlen=None for lint check, cause this arg is not used. Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ assert self.use_flash_attn is True q, k, v = self.wq(x), self.wk(x), self.wv(x) q = rearrange(q, "t (h d) -> t h d", d=self.head_dim) k = rearrange(k, "t (h d) -> t h d", d=self.head_dim) v = rearrange(v, "t (h d) -> t h d", d=self.head_dim) # qkv shift # the rotary embedding in flash attention module in performed by separating the front and back parts, while # most of others are done by odd-even methods. if not self.rot_embed_HF_impl: q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) indexes = kwargs.pop("indexes") q = self.rotary_emb._single_forward(q, indexes=indexes) k = self.rotary_emb._single_forward(k, indexes=indexes) if inference_params is None: kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1) if self.dtype is torch.float32: if q.dtype not in [torch.float16, torch.bfloat16]: q = q.to(torch.bfloat16) if kv.dtype not in [torch.float16, torch.bfloat16]: kv = kv.to(torch.bfloat16) with torch.cuda.amp.autocast(dtype=torch.bfloat16): context = flash_attn_varlen_kvpacked_func( q=q, kv=kv, cu_seqlens_q=kwargs["cu_seqlens"], cu_seqlens_k=kwargs["cu_seqlens"], max_seqlen_q=kwargs["max_seqlen"], max_seqlen_k=kwargs["max_seqlen"], dropout_p=self.inner_cross_attn_dropout, softmax_scale=self.inner_cross_attn_softmax_scale, causal=self.inner_cross_attn_causal, ).to(self.dtype) else: context = flash_attn_varlen_kvpacked_func( q=q, kv=kv, cu_seqlens_q=kwargs["cu_seqlens"], cu_seqlens_k=kwargs["cu_seqlens"], max_seqlen_q=kwargs["max_seqlen"], max_seqlen_k=kwargs["max_seqlen"], dropout_p=self.inner_cross_attn_dropout, softmax_scale=self.inner_cross_attn_softmax_scale, causal=self.inner_cross_attn_causal, ) else: raise RuntimeError("Not support this right now") context = rearrange(context, "b h d -> b (h d)") # recover shape out = self.wo(context) return out class PackedFlashLlamaLayer1D(nn.Module): """ 1D Packed Flash Llama Layer. Args: hidden_size (int): The hidden size of model. 768 by default. num_attention_heads (int): The number of attention heads. 12 by default. num_kv_attention_heads (int): The number of kv attention heads. 12 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. dtype (torch.dtype): Type of data. torch.float by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. layer_idx (int): The index of current layer. 0 by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. apply_post_layer_norm (bool): Whether use post layer norm. False by default. fused_dropout_add_ln (bool): Whether use fused dropout add ln. True by default. no_bias (bool): Whether remove bias. False by default. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. adapt_hf (bool): Whether adapt hf. False by default. dropout_selective_checkpoint (bool): Whether use dropout selective checkpoint. True by default. use_scaled_init (bool): Whether use scaled init. True by default. use_swiglu (bool): Whether use swiglu. True by default. use_flash_attn (bool): Whether use flash-attn. True by default. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu otherwise init fc1 weight in ffn. 0.02 by default, ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity). moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. """ def __init__( self, hidden_size: int = 768, num_attention_heads: int = 12, num_kv_attention_heads: int = 12, mlp_ratio: int = 4, attn_drop_rate: float = 0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, layer_norm_epsilon: float = 1e-6, checkpoint: bool = False, layer_idx: int = 0, residual_in_fp32: bool = False, device: Optional[torch.device] = None, apply_post_layer_norm: bool = False, fused_dropout_add_ln: bool = True, no_bias: bool = False, norm_type: str = "rmsnorm", adapt_hf: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, ffn_other_init_std: float = 0.02, init_type: str = "normal", rope_base: int = 10000, num_experts: int = 1, moe_gate_k: int = 1, moe_capacity_factor: float = 1.0, moe_eval_capacity_factor: float = 1.0, moe_min_capacity: int = 4, moe_noisy_gate_policy: str = None, moe_drop_tokens: bool = True, moe_use_rts: bool = True, moe_use_residual: bool = False, ): super().__init__() self.checkpoint = checkpoint # dropout selective checkpoint can only be enabled when checkpoint is disabled. self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln self.attn_wqkv_init_std = attn_wqkv_init_std self.attn_other_init_std = attn_other_init_std self.ffn_uplayer_init_std = ffn_uplayer_init_std self.ffn_other_init_std = ffn_other_init_std head_dim = hidden_size // num_attention_heads self.attention = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, num_kv_heads=num_kv_attention_heads, process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, softmax_scale=1 / math.sqrt(head_dim), causal=True, layer_idx=layer_idx, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, use_flash_attn=use_flash_attn, device=device, dtype=dtype, rot_embed_HF_impl=adapt_hf, bias=not no_bias, rope_base=rope_base, ) self.dropout1 = nn.Dropout(drop_rate) if norm_type == "rmsnorm": self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) else: self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) set_fp32_attr_to_module(self.attention_norm) set_fp32_attr_to_module(self.ffn_norm) if self.fused_dropout_add_ln: assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) self.num_experts = num_experts self.moe_gate_k = moe_gate_k self.moe_capacity_factor = moe_capacity_factor self.moe_eval_capacity_factor = moe_eval_capacity_factor self.moe_min_capacity = moe_min_capacity self.moe_noisy_gate_policy = moe_noisy_gate_policy self.moe_drop_tokens = moe_drop_tokens self.moe_use_rts = moe_use_rts self.moe_use_residual = moe_use_residual ep_size = gpc.get_world_size(ParallelMode.EXPERT) if num_experts <= 1: # dense, not MoE if use_swiglu: self.feed_forward = FeedForward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, ) else: self.feed_forward = ParallelFusedMLP( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, activation="gelu_approx", process_group=gpc.get_group(ParallelMode.TENSOR), bias1=False, bias2=False, sequence_parallel=sequence_parallel, checkpoint_lvl=0, heuristic="auto", device=device, dtype=dtype, ) for _, param in self.feed_forward.named_parameters(): if gpc.get_world_size(ParallelMode.TENSOR) > 1: setattr(param, IS_TENSOR_PARALLEL, True) else: # replace mlp by MoE module. The expert in MoE is a FeedForward module. self.feed_forward = MoE( hidden_size=hidden_size, num_experts=num_experts, ep_size=ep_size, k=moe_gate_k, capacity_factor=moe_capacity_factor, eval_capacity_factor=moe_eval_capacity_factor, min_capacity=moe_min_capacity, noisy_gate_policy=moe_noisy_gate_policy, drop_tokens=moe_drop_tokens, use_rts=moe_use_rts, use_residual=moe_use_residual, device=device, dtype=dtype, ) for _, param in self.feed_forward.moe_layer.experts.named_parameters(): if gpc.get_world_size(ParallelMode.TENSOR) > 1: setattr(param, IS_TENSOR_PARALLEL, True) set_fp32_attr_to_module(self.feed_forward.moe_layer.gate) for param in self.attention_norm.parameters(): if gpc.config.parallel.sequence_parallel is True: setattr(param, IS_SEQUENCE_PARALLEL, True) for param in self.ffn_norm.parameters(): if gpc.config.parallel.sequence_parallel is True: setattr(param, IS_SEQUENCE_PARALLEL, True) self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm self.return_residual = False if init_type == "normal": self.init_func = normal_ self.scaled_init_func = scaled_init_method_normal else: self.init_func = uniform_ self.scaled_init_func = scaled_init_method_uniform self.reset_parameters() def reset_parameters(self): with torch.no_grad(): for name, param in self.attention.named_parameters(): if param.ndim == 1: param.data.zero_() elif "wq" in name or "wk" in name or "wv" in name: self.init_func(std=self.attn_wqkv_init_std)(param.data) elif self.use_scaled_init: # wo self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) else: self.init_func(std=self.attn_other_init_std)(param.data) for name, param in self.feed_forward.named_parameters(): if self.use_swiglu: if self.use_scaled_init and "w2" in name: self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) else: self.init_func( std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std )(param.data) else: if self.use_scaled_init and "fc1" not in name: self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) else: self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( param.data ) def forward( self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None ): if self.checkpoint and self.training: return activation_checkpoint( self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen ) else: return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) def _forward( self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None ): r"""Pass the input through the encoder layer. Args: hidden_states: the sequence to the encoder layer (required). residual: hidden_states = Attn/MLP(LN(residual)) cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ if self.prenorm: def _dropout_and_norm_attn(_residual, _hidden_states): _dropped = self.dropout1(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) else: residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) if self.residual_in_fp32: residual = residual.to(torch.float32) mixer_kwargs = { "cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, "indexes": indexes, "inference_params": inference_params, } hidden_states = self.attention(hidden_states, **mixer_kwargs) if not isinstance(self.feed_forward, nn.Identity): if not self.fused_dropout_add_ln: def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped _hidden_states = self.ffn_norm(_residual.to(torch.float32)) return _residual, _hidden_states if self.dropout_selective_checkpoint: residual, hidden_states = activation_checkpoint( _dropout_and_norm_ffn, False, residual, hidden_states ) else: residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) if self.residual_in_fp32: residual = residual.to(torch.float32) # MLP. moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) if self.num_experts <= 1: # dense mlp output hidden_states = self.feed_forward(hidden_states) else: # MoE output hidden_states, moe_loss, _ = self.feed_forward(hidden_states) return hidden_states + residual, moe_loss else: assert residual is None mixer_kwargs = { "cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, "indexes": indexes, "inference_params": inference_params, } mixer_out = self.attention(hidden_states, **mixer_kwargs) if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( dtype=self.attention_norm.weight.dtype ) if not isinstance(self.feed_forward, nn.Identity): # MLP. moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) if self.num_experts <= 1: # dense mlp output mlp_out = self.feed_forward(hidden_states) else: # MoE output mlp_out, moe_loss, _ = self.feed_forward(hidden_states) if self.return_residual: # mlp out is actually a pair here # NOTE: should not be here mlp_out, hidden_states = mlp_out hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( dtype=self.ffn_norm.weight.dtype ) return hidden_states, moe_loss class PackedFlashLlama1D(nn.Module): """ 1D Packed Flash Llama. Args: num_layers (int): The number of layer. 12 by default. hidden_size (int): The size of hidden state. 768 by default. num_attention_heads (int): The number of attention head. 12 by default. num_kv_attention_heads (int): The number of kv attention head. 12 by default. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. dtype (torch.dtype): The type of data. torch.float by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number of layers. 1.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. apply_post_layer_norm (bool): Whether use post layer norm. False by default. no_bias (bool): Whether remove bias. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. adapt_hf (bool): Whether adapt hf. False by default. is_reward (bool): Whether use is_reward. False by default. dropout_selective_checkpoint (bool): Whether dropout selective checkpoint. True by default. use_scaled_init (bool): Whether use scaled init. True by default. use_swiglu (bool): Whether use swiglu. True by default. use_flash_attn (bool): Whether to use flash-attn. True by default. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu otherwise init fc1 weight in ffn. 0.02 by default, ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity). moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. """ def __init__( self, num_layers: int = 12, hidden_size: int = 768, num_attention_heads: int = 12, num_kv_attention_heads: int = 12, vocab_size: int = 50304, mlp_ratio: int = 4, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, checkpoint: bool = False, checkpoint_fraction: float = 1.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, device: Optional[torch.device] = None, apply_post_layer_norm=False, no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", adapt_hf: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, ffn_other_init_std: float = 0.02, out_head_init_std: float = 0.02, init_type: str = "normal", rope_base: int = 10000, num_experts: bool = 1, moe_gate_k: int = 1, moe_capacity_factor: float = 1.0, moe_eval_capacity_factor: float = 1.0, moe_min_capacity: int = 4, moe_noisy_gate_policy: str = None, moe_drop_tokens: bool = True, moe_use_rts: bool = True, moe_use_residual: bool = False, ): super().__init__() self.use_flash_attn = use_flash_attn if checkpoint_fraction <= 0: checkpoint = False if not checkpoint: checkpoint_fraction = 0 checkpoint_layer_num = num_layers * checkpoint_fraction sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) if is_reward: head_cls = RewardModelLinear else: head_cls = ScaleColumnParallelLinear if first: if embed_split_hidden: self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) else: self.tok_embeddings = ParallelGPT2Embeddings( embed_dim=hidden_size, vocab_size=vocab_size, max_position_embeddings=-1, process_group=gpc.get_group(ParallelMode.TENSOR), padding_idx=None, sequence_parallel=sequence_parallel, device=device, dtype=dtype, ) for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) if gpc.get_world_size(ParallelMode.TENSOR) > 1: setattr(param, IS_TENSOR_PARALLEL, True) self.embed_grad_scale = embed_grad_scale self.layers = nn.ModuleList( [ PackedFlashLlamaLayer1D( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, mlp_ratio=mlp_ratio, attn_drop_rate=attn_drop_rate, drop_rate=drop_rate, dtype=dtype, layer_norm_epsilon=layer_norm_epsilon, checkpoint=lid < checkpoint_layer_num, layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation residual_in_fp32=residual_in_fp32, device=device, apply_post_layer_norm=apply_post_layer_norm, fused_dropout_add_ln=False, no_bias=no_bias, norm_type=norm_type, dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, adapt_hf=adapt_hf, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, rope_base=rope_base, num_experts=num_experts, moe_gate_k=moe_gate_k, moe_capacity_factor=moe_capacity_factor, moe_eval_capacity_factor=moe_eval_capacity_factor, moe_min_capacity=moe_min_capacity, moe_noisy_gate_policy=moe_noisy_gate_policy, moe_drop_tokens=moe_drop_tokens, moe_use_rts=moe_use_rts, moe_use_residual=moe_use_residual, ) for lid in range(num_layers) ] ) if last: if not apply_post_layer_norm: if norm_type == "rmsnorm": self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) else: self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) for param in self.norm.parameters(): if gpc.config.parallel.sequence_parallel is True: setattr(param, IS_SEQUENCE_PARALLEL, True) self.output = head_cls( in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, weight_scale=embed_grad_scale, ) for _, param in self.output.named_parameters(): if init_type == "normal": normal_(std=out_head_init_std)(param) else: uniform_(std=out_head_init_std)(param) if gpc.get_world_size(ParallelMode.TENSOR) > 1: setattr(param, IS_TENSOR_PARALLEL, True) self.parallel_output = parallel_output def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "tok_embeddings"): hidden_states = self.tok_embeddings(input_ids) if self.embed_grad_scale != 1: hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) if isinstance(cu_seqlens, list): assert len(cu_seqlens) == 1 cu_seqlens = cu_seqlens[0].to(hidden_states.device) if cu_seqlens is not None: cu_seqlens = cu_seqlens.squeeze(0) hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state, # the batch dimension with a size of 1 should be directly squeezed off. if indexes is not None: assert len(indexes) == 1 # The indexes are used to indicate the actual position IDs of each token in the packed input. indexes = indexes[0] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None moe_losses = [] for _, block in enumerate(self.layers): hidden_states, mos_loss = block( hidden_states, residual=None, cu_seqlens=cu_seqlens, indexes=indexes, inference_params=inference_params, max_seqlen=max_seqlen, ) moe_losses.append(mos_loss) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) extra_hidden_states_list = None if hasattr(self, "output"): hidden_states = self.output(hidden_states) if not self.parallel_output: hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) if extra_hidden_states_list is not None: extra_hidden_states_list = [ gather_forward_split_backward(extra_hidden_states, ParallelMode.TENSOR, dim=-1) for extra_hidden_states in extra_hidden_states_list # pylint: disable=E1133 ] if extra_hidden_states_list is not None: return (hidden_states, extra_hidden_states_list) return hidden_states, moe_losses def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): """ build generic model 1d Args: num_layers (int): The number of layer. num_chunks (int): The number of partitions in pipeline parallel. device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. """ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) parts = all_parts[pipeline_rank] if gpc.is_rank_for_log(): logger.info(f"The layer sharding is {all_parts}.") models = [] kwargs["checkpoint_fraction"] = 1.0 start_idx, end_idx = 0, 0 for start, end in parts: start_idx, end_idx = start, end kwargs["num_layers"] = end - start kwargs["first"] = start == 0 # If there is no content in the final layer, assign the last layer. kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 kwargs["device"] = device kwargs["start_layer_idx"] = start chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) models.append(chunk) torch.distributed.barrier() if len(models) == 1: model = models[0] else: model = nn.ModuleList(models) setattr(model, "first_layer", start_idx) setattr(model, "last_layer", end_idx) return model @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) def build_model_with_moe_cfg( num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, num_layers=48, hidden_size=2048, vocab_size=50304, embed_grad_scale=1, parallel_output=True, num_attention_heads=32, num_kv_attention_heads=None, mlp_ratio=4.0, residual_in_fp32=False, norm_type="rmsnorm", adapt_hf=False, drop_rate=0, attn_drop_rate=0, apply_post_layer_norm=False, # pylint: disable=W0613 no_bias=False, deepnorm=False, # pylint: disable=W0613 layer_norm_epsilon=1e-5, is_reward=False, dropout_selective_checkpoint=True, use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, ffn_other_init_std: float = 0.02, out_head_init_std: float = 0.02, init_type: str = "normal", rope_base: int = 10000, num_experts: int = 1, moe_gate_k: int = 1, moe_capacity_factor: float = 1.0, moe_eval_capacity_factor: float = 1.0, moe_min_capacity: int = 4, moe_noisy_gate_policy: str = None, moe_drop_tokens: bool = True, moe_use_rts: bool = True, moe_use_residual: bool = False, ): """ Build model with config. Args: num_chunks (int): The number of partitions in pipeline parallel. 1 by default. checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. dtype (torch.dtype): The type of data. torch.float by default. embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. False by default. num_layers (int): The number of layer. 48 by default. hidden_size (int): The size of hidden state. 2048 by default. vocab_size (int): The size of vocabulary. 50304 by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. num_attention_heads (int): The number of attention head. 32 by default. mlp_ratio (int): The ratio of MLP layers. 4.0 by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily because this parameter requires inconsistent data types to be passed between pipelines, which requires significant modifications to internlm. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. drop_rate (float): The dropout rate of input hidden state. 0 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. is_reward (bool): Whether to use reward model. False by default. dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. use_scaled_init (bool): Whether to use scaled init. True by default. use_swiglu (bool): Whether to use swiglu. True by default. use_flash_attn (bool): Whether to use flash-attn. True by default. num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'. moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity). moe_use_rts (bool, optional): default=True, whether to use Random Token Selection. moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. """ cfg = dict( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, checkpoint=checkpoint, dtype=dtype, embed_split_hidden=embed_split_hidden, vocab_size=vocab_size, embed_grad_scale=embed_grad_scale, parallel_output=parallel_output, mlp_ratio=mlp_ratio, apply_post_layer_norm=apply_post_layer_norm, no_bias=no_bias, residual_in_fp32=residual_in_fp32, norm_type=norm_type, adapt_hf=adapt_hf, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, layer_norm_epsilon=layer_norm_epsilon, is_reward=is_reward, dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, embedding_init_std=embedding_init_std, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, out_head_init_std=out_head_init_std, init_type=init_type, rope_base=rope_base, num_experts=num_experts, moe_gate_k=moe_gate_k, moe_capacity_factor=moe_capacity_factor, moe_eval_capacity_factor=moe_eval_capacity_factor, moe_min_capacity=moe_min_capacity, moe_noisy_gate_policy=moe_noisy_gate_policy, moe_drop_tokens=moe_drop_tokens, moe_use_rts=moe_use_rts, moe_use_residual=moe_use_residual, ) return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)