support hf llama

pull/532/head
lijiaxing 2023-12-08 20:13:34 +08:00
parent 41edd074a6
commit 4b7fa26d80
1 changed files with 40 additions and 17 deletions

View File

@ -5,21 +5,8 @@ from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mha import (
CrossAttention,
FlashCrossAttention,
FlashSelfAttention,
SelfAttention,
_update_kv_cache,
)
from flash_attn.modules.mlp import ParallelFusedMLP
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from torch import nn
# isort: off
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import (
@ -43,7 +30,21 @@ from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER
# isort: on
try:
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mha import (
CrossAttention,
FlashCrossAttention,
FlashSelfAttention,
SelfAttention,
_update_kv_cache,
)
from flash_attn.modules.mlp import ParallelFusedMLP
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
raise ImportError("Please check your flash_attn version >= 2.0.0.")
MODEL_TYPE = "LLAMA"
@ -58,6 +59,7 @@ class MHA(nn.Module):
Args:
embed_dim (int): The dimention of hidden state.
num_heads (int): The number of attention heads.
num_kv_heads (int): The number of kv attention heads.
process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
output projection. True by default.
@ -65,15 +67,16 @@ class MHA(nn.Module):
softmax_scale (float): The temperature to use for the softmax attention.
causal (boolean): Whether to apply causal attention mask. False by default.
layer_idx (int): The index of current layer. None by default.
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default.
rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used.
False by default.
True by default.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_flash_attn (bool): Whether to use flash-attn. True by default.
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
rot_embed_HF_impl: rotary embedding hf implementation. False by default.
"""
@ -461,6 +464,7 @@ class PackedFlashLlamaLayer1D(nn.Module):
Args:
hidden_size (int): The hidden size of model. 768 by default.
num_attention_heads (int): The number of attention heads. 12 by default.
num_kv_attention_heads (int): The number of kv attention heads. 12 by default.
mlp_ratio (int): The ratio of MLP layers. 4 by default.
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
@ -470,7 +474,14 @@ class PackedFlashLlamaLayer1D(nn.Module):
layer_idx (int): The index of current layer. 0 by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
device (Optional[Union[str, torch.device]]): The device will be used.
apply_post_layer_norm (bool): Whether use post layer norm. False by default.
fused_dropout_add_ln (bool): Whether use fused dropout add ln. True by default.
no_bias (bool): Whether remove bias. False by default.
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
adapt_hf (bool): Whether adapt hf. False by default.
dropout_selective_checkpoint (bool): Whether use dropout selective checkpoint. True by default.
use_scaled_init (bool): Whether use scaled init. True by default.
use_swiglu (bool): Whether use swiglu. True by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
@ -735,6 +746,7 @@ class PackedFlashLlama1D(nn.Module):
num_layers (int): The number of layer. 12 by default.
hidden_size (int): The size of hidden state. 768 by default.
num_attention_heads (int): The number of attention head. 12 by default.
num_kv_attention_heads (int): The number of kv attention head. 12 by default.
vocab_size (int): The size of vocabulary. 50304 by default.
mlp_ratio (int): The ratio of MLP layers. 4 by default.
attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
@ -752,8 +764,15 @@ class PackedFlashLlama1D(nn.Module):
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
apply_post_layer_norm (bool): Whether use post layer norm. False by default.
no_bias (bool): Whether remove bias. False by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
adapt_hf (bool): Whether adapt hf. False by default.
is_reward (bool): Whether use is_reward. False by default.
dropout_selective_checkpoint (bool): Whether dropout selective checkpoint. True by default.
use_scaled_init (bool): Whether use scaled init. True by default.
use_swiglu (bool): Whether use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
embedding_init_std (float): std used to init embedding weight. 0.02 by default,
attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
@ -1085,14 +1104,18 @@ def build_model_with_cfg(
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
num_attention_heads (int): The number of attention head. 32 by default.
num_kv_attention_heads (int): The number of kv attention head. None by default.
mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
because this parameter requires inconsistent data types to be passed between pipelines,
which requires significant modifications to internlm.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
adapt_hf (bool): Whether adapt hf. False by default.
drop_rate (float): The dropout rate of input hidden state. 0 by default.
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
no_bias (bool): Whether remove bias. False by default.
deepnorm (bool): Whether us deepnorm. False by default.
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
is_reward (bool): Whether to use reward model. False by default.
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.