mirror of https://github.com/InternLM/InternLM

commit 4b7fa26d80: support hf llama
parent 41edd074a6
@@ -5,21 +5,8 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_varlen_kvpacked_func
-from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc
-from flash_attn.modules.embedding import ParallelGPT2Embeddings
-from flash_attn.modules.mha import (
-    CrossAttention,
-    FlashCrossAttention,
-    FlashSelfAttention,
-    SelfAttention,
-    _update_kv_cache,
-)
-from flash_attn.modules.mlp import ParallelFusedMLP
-from flash_attn.ops.layer_norm import dropout_add_layer_norm
 from torch import nn

 # isort: off
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import (
@@ -43,7 +30,21 @@ from internlm.utils.common import filter_kwargs
 from internlm.utils.logger import get_logger
 from internlm.utils.registry import MODEL_INITIALIZER

 # isort: on
+try:
+    from flash_attn import flash_attn_varlen_kvpacked_func
+    from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc
+    from flash_attn.modules.embedding import ParallelGPT2Embeddings
+    from flash_attn.modules.mha import (
+        CrossAttention,
+        FlashCrossAttention,
+        FlashSelfAttention,
+        SelfAttention,
+        _update_kv_cache,
+    )
+    from flash_attn.modules.mlp import ParallelFusedMLP
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    raise ImportError("Please check your flash_attn version >= 2.0.0.")

 MODEL_TYPE = "LLAMA"
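The guarded import makes flash-attn a hard requirement for this model while surfacing an actionable error early instead of a bare ModuleNotFoundError deep in model construction. A minimal sketch of the same pattern, with an explicit version check added on the assumption (not shown in this commit) that `flash_attn.__version__` is available:

```python
# Sketch of the guarded-import pattern above; the version parse is an
# illustrative assumption, not code from this commit.
try:
    import flash_attn

    major, minor = (int(p) for p in flash_attn.__version__.split(".")[:2])
    if (major, minor) < (2, 0):
        raise ImportError
except ImportError as exc:
    raise ImportError("Please check your flash_attn version >= 2.0.0.") from exc
```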
@@ -58,6 +59,7 @@ class MHA(nn.Module):
     Args:
         embed_dim (int): The dimension of the hidden state.
         num_heads (int): The number of attention heads.
+        num_kv_heads (int): The number of kv attention heads.
         process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
         bias (boolean): Whether the bias is needed for linears. Will be used when initializing the QKV matrix and
             output projection. True by default.
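The new `num_kv_heads` argument is the grouped-query attention knob: a smaller set of key/value heads is shared across groups of query heads, as in HF llama. A minimal sketch of the head bookkeeping, assuming `num_kv_heads` evenly divides `num_heads` (shapes and names here are illustrative, not the repo's internals):

```python
import torch

num_heads, num_kv_heads, head_dim = 32, 8, 128
q_per_kv = num_heads // num_kv_heads           # 4 query heads share one kv head

q = torch.randn(1, 16, num_heads, head_dim)    # (batch, seqlen, heads, head_dim)
k = torch.randn(1, 16, num_kv_heads, head_dim)
# A non-fused attention would expand kv heads across their query groups;
# fused kernels do this implicitly without materializing the copy.
k_expanded = k.repeat_interleave(q_per_kv, dim=2)
assert k_expanded.shape == q.shape
```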
@@ -65,15 +67,16 @@ class MHA(nn.Module):
         softmax_scale (float): The temperature to use for the softmax attention.
         causal (boolean): Whether to apply a causal attention mask. False by default.
         layer_idx (int): The index of the current layer. None by default.
-        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
         rotary_emb_dim (int): The dimension of the rotary embedding. 0 by default.
         rotary_emb_scale_base (int): The scaling factor of the rotary embedding. If scale_base > 0, this implements
             XPos (Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
         use_flash_attn (boolean): Whether to use flash attention. If False, the vanilla attention module is used.
-            False by default.
+            True by default.
         device (Optional[Union[str, torch.device]]): The device to be used.
         dtype (Optional[torch.dtype]): The type of data.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+        rot_embed_HF_impl (bool): Whether to use the HF rotary embedding implementation. False by default.

     """
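The new `rot_embed_HF_impl` flag concerns the RoPE channel layout: HF llama rotates the two contiguous halves of each head, while the original interleaved formulation rotates adjacent channel pairs. A sketch of the two conventions, under the assumption that this flag selects between them (function names are illustrative):

```python
import torch

def rotate_half_hf(x):
    # HF llama layout: split the head dim into two contiguous halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleaved(x):
    # Original RoPE layout: rotate adjacent even/odd channel pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope(x, cos, sin, hf_impl=False):
    # cos/sin must be laid out to match the chosen convention.
    rotate = rotate_half_hf if hf_impl else rotate_interleaved
    return x * cos + rotate(x) * sin
```

Checkpoints converted between the two layouts need their wq/wk weights permuted to match, which is presumably part of what the HF adaptation path handles.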
@@ -461,6 +464,7 @@ class PackedFlashLlamaLayer1D(nn.Module):
     Args:
         hidden_size (int): The hidden size of the model. 768 by default.
         num_attention_heads (int): The number of attention heads. 12 by default.
+        num_kv_attention_heads (int): The number of kv attention heads. 12 by default.
         mlp_ratio (int): The ratio of MLP layers. 4 by default.
         attn_drop_rate (float): The dropout rate of the attention module. 0 by default.
         drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
@@ -470,7 +474,14 @@ class PackedFlashLlamaLayer1D(nn.Module):
         layer_idx (int): The index of the current layer. 0 by default.
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         device (Optional[Union[str, torch.device]]): The device to be used.
+        apply_post_layer_norm (bool): Whether to use post layer norm. False by default.
+        fused_dropout_add_ln (bool): Whether to use fused dropout_add_layer_norm. True by default.
+        no_bias (bool): Whether to remove bias. False by default.
+        norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
+        adapt_hf (bool): Whether to adapt to HF llama. False by default.
         dropout_selective_checkpoint (bool): Whether to use dropout selective checkpoint. True by default.
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        attn_wqkv_init_std (float): The std used to init the attn_wqkv weight. 0.02 by default.
+        attn_other_init_std (float): The std used to init the attn_other weight. 0.02 by default.
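`use_swiglu` selects a llama-style gated MLP. The diff only shows the flag, so here is an unfused sketch of the computation it names (the repo presumably routes this through a fused or tensor-parallel implementation):

```python
import torch.nn.functional as F
from torch import nn

class SwiGLU(nn.Module):
    """Unfused SwiGLU MLP sketch: silu(x @ W1) * (x @ W3), projected back by W2."""

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # down projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```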
@@ -735,6 +746,7 @@ class PackedFlashLlama1D(nn.Module):
         num_layers (int): The number of layers. 12 by default.
         hidden_size (int): The size of the hidden state. 768 by default.
         num_attention_heads (int): The number of attention heads. 12 by default.
+        num_kv_attention_heads (int): The number of kv attention heads. 12 by default.
         vocab_size (int): The size of the vocabulary. 50304 by default.
         mlp_ratio (int): The ratio of MLP layers. 4 by default.
         attn_drop_rate (float): The dropout rate of the attention module. 0.0 by default.
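"Packed" in the class name refers to the varlen layout consumed by the flash-attn kernels imported above: sequences are concatenated along a single token axis and delimited by cumulative sequence lengths. A sketch of the shapes involved (values are illustrative):

```python
import torch

seqlens = [5, 3, 8]                    # three sequences packed into one batch
cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32)  # prefix sums
total, num_heads, head_dim = sum(seqlens), 32, 128

q = torch.randn(total, num_heads, head_dim, dtype=torch.float16)
kv = torch.randn(total, 2, num_heads, head_dim, dtype=torch.float16)
# out = flash_attn_varlen_kvpacked_func(
#     q, kv, cu_seqlens, cu_seqlens, max_seqlen_q=8, max_seqlen_k=8, causal=True
# )  # requires a CUDA device, so shown here as a comment
```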
@@ -752,8 +764,15 @@ class PackedFlashLlama1D(nn.Module):
         parallel_output (bool): Whether it is necessary to collect the output of parallel computing. True by default.
         start_layer_idx (int): The index of the start layer in the pipeline. 0 by default.
         device (Optional[Union[str, torch.device]]): The device to be used. None by default.
+        apply_post_layer_norm (bool): Whether to use post layer norm. False by default.
+        no_bias (bool): Whether to remove bias. False by default.
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
+        norm_type (str): Normalization type, RMSNorm or LayerNorm. "rmsnorm" by default.
+        adapt_hf (bool): Whether to adapt to HF llama. False by default.
         is_reward (bool): Whether to use a reward-model head. False by default.
         dropout_selective_checkpoint (bool): Whether to use dropout selective checkpoint. True by default.
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        embedding_init_std (float): The std used to init the embedding weight. 0.02 by default.
+        attn_wqkv_init_std (float): The std used to init the attn_wqkv weight. 0.02 by default.
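`norm_type="rmsnorm"` is the llama convention: no mean subtraction and no bias term, just root-mean-square scaling with a learned gain. A minimal sketch of the computation, assuming the documented `layer_norm_epsilon` default of 1e-5 (the repo's actual implementation may be fused):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """RMSNorm sketch: scale features by their reciprocal root-mean-square."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)
```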
@@ -1085,14 +1104,18 @@ def build_model_with_cfg(
         embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
         parallel_output (bool): Whether it is necessary to collect the output of parallel computing. True by default.
         num_attention_heads (int): The number of attention heads. 32 by default.
+        num_kv_attention_heads (int): The number of kv attention heads. None by default.
         mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
             because this parameter requires inconsistent data types to be passed between pipelines,
             which requires significant modifications to internlm.
         norm_type (str): Normalization type, RMSNorm or LayerNorm. "rmsnorm" by default.
+        adapt_hf (bool): Whether to adapt to HF llama. False by default.
         drop_rate (float): The dropout rate of the input hidden state. 0 by default.
         attn_drop_rate (float): The dropout rate of the attention module. 0 by default.
         apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
+        no_bias (bool): Whether to remove bias. False by default.
+        deepnorm (bool): Whether to use deepnorm. False by default.
         layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
         is_reward (bool): Whether to use a reward model. False by default.
         dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
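Taken together, the HF-oriented switches surface as model config keys consumed by `build_model_with_cfg`. A hypothetical fragment built only from the arguments documented above; an actual InternLM config file may spell or group these differently:

```python
# Hypothetical config fragment; keys mirror the documented build_model_with_cfg
# arguments, not a verified InternLM config file.
model_cfg = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    num_kv_attention_heads=8,   # grouped-query attention
    mlp_ratio=4.0,
    norm_type="rmsnorm",
    adapt_hf=True,              # follow HF llama conventions
    no_bias=True,
    use_flash_attn=True,
    layer_norm_epsilon=1e-5,
)
```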