# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
import itertools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed import ProcessGroup
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
)

from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor

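# Compiled inference kernels (rotary embedding, kv-cache copy, rms layernorm, flash-decoding
# attention, fused silu-and-mul, ...) are exposed as `inference_ops` and used throughout this
# file whenever `use_cuda_kernel` is enabled.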
inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)

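# flash-attn2 is preferred for the prefill (context) stage when it is installed and the dtype
# is FP16/BF16; otherwise the triton context attention kernel below is used as a fallback.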
try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")


def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    input_tokens_ids: torch.Tensor,
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
) -> torch.Tensor:
    """This function will replace the forward function of LlamaForCausalLM.

    Args:
        input_tokens_ids (torch.Tensor): The input token ids.
        output_tensor (torch.Tensor): The mid tensor that holds the output of attention.
        inputmetadata (InputMetaData): It stores the necessary input information for this inference.
        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
    """

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    hidden_states = llama_model_forward(
        self.model,
        input_tokens_ids=input_tokens_ids,
        output_tensor=output_tensor,
        inputmetadata=inputmetadata,
        k_caches=k_caches,
        v_caches=v_caches,
        use_cuda_kernel=inputmetadata.use_cuda_kernel,  # NOTE currently the CUDA kernels for layernorm and rotary_embedding_and_cache_copy couldn't pass the unit tests, but the triton kernels could
        high_precision=inputmetadata.high_precision,
    )

    logits = self.lm_head(hidden_states)
    return logits


def llama_model_forward(
    self: LlamaModel,
    input_tokens_ids: torch.Tensor,
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
    use_cuda_kernel: Optional[bool] = True,
    high_precision: bool = False,
) -> torch.Tensor:
    """This function will replace the forward function of LlamaModel.

    Args:
        input_tokens_ids (torch.Tensor): The input token ids.
        output_tensor (torch.Tensor): The mid tensor that holds the output of attention.
        inputmetadata (InputMetaData): It stores the necessary input information for this inference.
        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
        use_cuda_kernel (Optional[bool]): Whether to use CUDA kernels. Defaults to True.
        high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
    """
    block_tables = inputmetadata.block_tables
    sequence_lengths = inputmetadata.sequence_lengths
    kv_seq_len = inputmetadata.kv_seq_len

    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
    # during speculative decoding (`q_len > 1`).
    # We explicitly disable `use_cuda_kernel` here when speculative decoding is enabled.
    if inputmetadata.use_spec_dec and use_cuda_kernel:
        use_cuda_kernel = False
        logger.warning("CUDA kernel is disabled for speculative-decoding.")

    hidden_states = self.embed_tokens(input_tokens_ids)

    cu_seqlens = None
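
    # The rotary cos/sin values are prepared in one of three ways below: per-token gathering for
    # speculative decoding, the CUDA `get_cos_and_sin` kernel (plus `cu_seqlens` for flash-attn),
    # or the triton `get_xine_cache` helper.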
    # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
    if inputmetadata.use_spec_dec:
        # For speculative-decoding Prefill and Verifying Stage
        if inputmetadata.is_prompts:
            # output tensor shape is the same as normal Prefill Stage
            rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
        else:
            # the number of tokens to be verified in parallel plus the correct token in the last step
            n_tokens = inputmetadata.num_tokens_to_verify + 1
            assert n_tokens == hidden_states.size(0)
            rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
        rotary_indexes = torch.cat(rotary_indexes, dim=-1)
        cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

    elif use_cuda_kernel:
        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

        hidden_dim = self._cos_cached.size(-1)
        total_length = hidden_states.size(0)
        cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device)
        sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device)
        inference_ops.get_cos_and_sin(
            self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
        )
        cos_sin = (cos, sin)
    else:
        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

    norm_output = torch.empty_like(hidden_states)
    tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
    residual = None

    for layer_id, decoder_layer in enumerate(self.layers):
        hidden_states, residual = decoder_layer(
            hidden_states,
            residual=residual,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=inputmetadata.is_prompts,
            is_verifier=inputmetadata.use_spec_dec,
            tokens_to_verify=tokens_to_verify,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=inputmetadata.fd_inter_tensor,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
            norm_output=norm_output,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
            cu_seqlens=cu_seqlens,
            high_precision=high_precision,
        )
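
    # For the prefill (context) stage, only the last token of each sequence is needed to compute
    # the next-token logits; `seq_len_cumsum - 1` indexes those positions in the packed
    # (padding-free) layout.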
    if inputmetadata.is_prompts:
        seq_len_cumsum = sequence_lengths.cumsum(dim=0)
        hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
        residual = residual[seq_len_cumsum - 1].contiguous()
        norm_output = torch.empty_like(hidden_states)
    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

    return hidden_states


def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    block_tables: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    sequence_lengths: torch.Tensor,
    cos_sin: Tuple[torch.Tensor],
    fd_inter_tensor: FDIntermTensors,
    is_prompts: bool = True,
    is_verifier: bool = False,
    tokens_to_verify: int = None,
    kv_seq_len: int = 0,
    output_tensor: torch.Tensor = None,
    norm_output: torch.Tensor = None,
    sm_scale: int = None,
    use_cuda_kernel: bool = True,
    cu_seqlens: torch.Tensor = None,
    high_precision: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """This function will replace the forward function of LlamaDecoderLayer.

    Args:
        hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
        residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
        block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
            storing mapping of token_position_id -> block_id.
        k_cache (torch.Tensor): It holds the GPU memory for the key cache.
        v_cache (torch.Tensor): It holds the GPU memory for the value cache.
        sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
        cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
        fd_inter_tensor (FDIntermTensors): Holding tensors used for
            storing intermediate values in flash-decoding.
        is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
        is_verifier (bool, optional): Whether the current step is the verification stage of speculative decoding. Defaults to False.
        tokens_to_verify (int, optional): The number of tokens to be verified in parallel. Defaults to None.
        kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
        output_tensor (torch.Tensor, optional): The mid tensor that holds the output of attention. Defaults to None.
        norm_output (torch.Tensor, optional): The mid tensor that holds the output of layernorm. Defaults to None.
        sm_scale (int, optional): Used for flash attention. Defaults to None.
        use_cuda_kernel (bool, optional): Whether to use CUDA kernels. Defaults to True.
        cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths.
        high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
    """

    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        is_verifier=is_verifier,
        tokens_to_verify=tokens_to_verify,
        sequence_lengths=sequence_lengths,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        kv_seq_len=kv_seq_len,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
        use_cuda_kernel=use_cuda_kernel,
        cu_seqlens=cu_seqlens,
        high_precision=high_precision,
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    hidden_states = self.mlp(hidden_states)

    return hidden_states, residual


def llama_rmsnorm_forward(
    self: LlamaRMSNorm,
    hidden_states: torch.Tensor,
    norm_output: torch.Tensor,
    residual: torch.Tensor = None,
    use_cuda_kernel: bool = True,
):
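    # Returns a tuple (normalized_states, residual). On the CUDA path, `fused_add_rms_layernorm`
    # appears to update `hidden_states`/`residual` in place (its result is not reassigned); the
    # triton `rms_layernorm` fallback covers both the fused-add and plain cases.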
    if use_cuda_kernel:
        if residual is not None:
            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
            return hidden_states, residual

        if norm_output is None:
            norm_output = torch.empty_like(hidden_states)
        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
        return norm_output, hidden_states
    else:
        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)


class NopadLlamaMLP(LlamaMLP, ParallelModule):
    def __init__(
        self,
        config: LlamaConfig,
        mlp_gproj_w: torch.Tensor = None,
        mlp_uproj_w: torch.Tensor = None,
        mlp_dproj: ParallelModule = None,
        process_group: ProcessGroup = None,
    ):
        """A unified layer that replaces LlamaMLP for padding-free inference.

        Args:
            config (LlamaConfig): Holding the Llama model config.
            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
            mlp_dproj (Linear1D_Row, optional): The Linear1D_Row mlp_dproj weight. Defaults to None.
            process_group (ProcessGroup, optional): The process group for distributed execution. Defaults to None.
        """
        ParallelModule.__init__(self)
        self.config = config
        assert is_distributed_tensor(
            mlp_gproj_w
        ), "mlp_gproj_w must be dtensor so we could get the layout of the weight"
        self.helper_layout = (
            mlp_gproj_w.dist_layout
        )  # NOTE this is a hack for the right load/shard of gate_up_weight (used in _load_from_state_dict)
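        # gate_proj and up_proj weights are transposed to [hidden_size, inter_size] and stacked
        # into a single [2, hidden_size, inter_size] parameter so both projections can be
        # computed with one batched matmul in `forward`.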
        self.gate_up_weight = nn.Parameter(
            torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
        )
        self.down_proj = mlp_dproj
        self.process_group = process_group

    @staticmethod
    def from_native_module(
        module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        """Used to initialize the weights of NopadLlamaMLP from the original LlamaMLP.

        Args:
            module (LlamaMLP): The origin LlamaMLP layer.
        """

        config = module.config

        mlp_gproj_w = module.gate_proj.weight
        assert is_distributed_tensor(
            module.gate_proj.weight
        ), "gate_proj.weight must be dtensor so we could get the layout of the weight"
        mlp_uproj_w = module.up_proj.weight
        mlp_dproj = module.down_proj

        mlp_layer = NopadLlamaMLP(
            config=config,
            mlp_gproj_w=mlp_gproj_w,
            mlp_uproj_w=mlp_uproj_w,
            mlp_dproj=mlp_dproj,
            process_group=process_group,
        )

        return mlp_layer

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # NOTE This is a hack to ensure we could load the right weight from the LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)

        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        key = "gate_up_weight"
        k1 = "gate_proj.weight"
        k2 = "up_proj.weight"

        gate_w = state_dict[prefix + k1]
        up_w = state_dict[prefix + k2]
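
        # Re-shard the full checkpoint weights using the layout recorded from the original
        # gate_proj weight, then rebuild the fused gate_up parameter.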
        device_mesh = self.helper_layout.device_mesh
        sharding_spec = self.helper_layout.sharding_spec
        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)

        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)

        input_param = nn.Parameter(
            gate_up_w
        )  # NOTE gate_up_weight doesn't have to be a dtensor, like input_param = sharded_tensor_to_param(input_param)
        param = local_state[key]

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append(
                'While copying the parameter named "{}", '
                "whose dimensions in the model are {} and "
                "whose dimensions in the checkpoint are {}, "
                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
            )

        strict = False  # to avoid unexpected_keys
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
        """
        hidden_states = hidden_states.expand(2, -1, -1)
        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
        act_out = inference_ops.silu_and_mul(gate_up_proj_out)

        return self.down_proj(act_out)

    def extra_repr(self) -> str:
        return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"


class NopadLlamaAttention(LlamaAttention, ParallelModule):
    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: Optional[int] = None,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        attn_oproj: ParallelModule = None,
        process_group: ProcessGroup = None,
        num_heads: int = None,
        hidden_size: int = None,
        num_key_value_heads: int = None,
    ):
        """This layer will replace the LlamaAttention.

        Args:
            config (LlamaConfig): Holding the Llama model config.
            layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
            process_group (ProcessGroup, optional): The process group for distributed execution. Defaults to None.
            num_heads (int, optional): The number of attention heads. Defaults to None.
            hidden_size (int, optional): The hidden size of the model. Defaults to None.
            num_key_value_heads (int, optional): The number of key/value heads. Defaults to None.
        """
        ParallelModule.__init__(self)
        self.config = config
        self.layer_idx = layer_idx

        self.o_proj = attn_oproj
        self.process_group = process_group

        self.attention_dropout = config.attention_dropout
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
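
        # For plain multi-head attention (no GQA) the q/k/v projection weights are stacked into a
        # single [3, hidden_size, hidden_size] parameter so QKV can be computed with one bmm;
        # with grouped-query attention the shapes differ, so the weights are kept separate.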
        if self.num_heads == self.num_key_value_heads:
            qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
            self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
            self.helper_layout = (
                attn_qproj_w.dist_layout
            )  # NOTE this is a hack for the right load/shard of qkv_weight (used in _load_from_state_dict)
        else:
            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())

    @staticmethod
    def from_native_module(
        module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        """Used to initialize the weights of NopadLlamaAttention from the original LlamaAttention.

        Args:
            module (LlamaAttention): The origin LlamaAttention layer.
        """

        config = module.config
        layer_idx = module.layer_idx

        attn_qproj_w = module.q_proj.weight
        attn_kproj_w = module.k_proj.weight
        attn_vproj_w = module.v_proj.weight
        assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
        attn_oproj = module.o_proj

        attn_layer = NopadLlamaAttention(
            config=config,
            layer_idx=layer_idx,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
            attn_oproj=attn_oproj,
            process_group=process_group,
            num_heads=module.num_heads,
            hidden_size=module.hidden_size,
            num_key_value_heads=module.num_key_value_heads,
        )

        return attn_layer

    # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        block_tables: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        sequence_lengths: torch.Tensor,
        cos_sin: Tuple[torch.Tensor],
        fd_inter_tensor: FDIntermTensors,
        is_prompts: bool = True,
        is_verifier: bool = False,
        tokens_to_verify: int = None,
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: int = None,
        use_cuda_kernel: bool = True,
        cu_seqlens: torch.Tensor = None,
        high_precision: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
            block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                storing mapping of token_position_id -> block_id.
            k_cache (torch.Tensor): It holds the GPU memory for the key cache.
            v_cache (torch.Tensor): It holds the GPU memory for the value cache.
            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
            cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
            fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
                storing intermediate values in flash-decoding.
            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
            is_verifier (bool, optional): Whether the current step is the verification stage of speculative decoding. Defaults to False.
            tokens_to_verify (int, optional): The number of tokens to be verified in parallel. Defaults to None.
            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
            output_tensor (torch.Tensor, optional): The mid tensor that holds the output of attention. Defaults to None.
            sm_scale (int, optional): Used for flash attention. Defaults to None.
            use_cuda_kernel (bool, optional): Whether to use CUDA kernels. Defaults to True.
            cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths.
            high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
        """

        token_nums = hidden_states.size(0)

        if self.num_heads != self.num_key_value_heads:
            query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
            key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
            value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
        else:
            # fused qkv
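            # hidden_states is broadcast to [3, token_num, hidden_size] (a view, no copy) so a single
            # bmm against the stacked [3, hidden_size, hidden_size] qkv_weight produces q, k and v.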
            hidden_states = hidden_states.expand(3, -1, -1)
            query_states, key_states, value_states = (
                torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
            )

        block_size = k_cache.size(-2)
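
        # Prefill (context) stage: apply rotary embedding, write k/v into the paged caches, and run
        # full context attention (flash-attn2 for FP16/BF16 when installed, otherwise the triton kernel).
        # Decode stage (else branch below): only the new token(s) attend to the cached k/v via flash-decoding.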
        if is_prompts:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )

                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                    use_new_kcache_layout=use_cuda_kernel,
                )
        else:
            q_len = tokens_to_verify + 1 if is_verifier else 1
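
            # CUDA path: rotary embedding and the k/v cache copy are handled by a single fused op,
            # then `flash_decoding_attention` computes attention directly from the paged caches.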
            if use_cuda_kernel:
                inference_ops.rotary_embedding_and_cache_copy(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    high_precision,
                )
                inference_ops.flash_decoding_attention(
                    output_tensor,
                    query_states,
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    block_size,
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.mid_output_lse,
                    None,
                    sm_scale,
                )
                attn_output = output_tensor
            else:
                if is_verifier:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
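                    # copy the last q_len key tokens into the blocked k_cache; the same helper is
                    # reused below to write the value tokens into v_cache.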
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                else:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                attn_output = flash_decoding_attention(
                    q=query_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    kv_seq_len=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    max_seq_len_in_batch=kv_seq_len,
                    output=output_tensor,
                    mid_output=fd_inter_tensor.mid_output,
                    mid_output_lse=fd_inter_tensor.mid_output_lse,
                    sm_scale=sm_scale,
                    kv_group_num=self.num_key_value_groups,
                    q_len=q_len,
                )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # NOTE This is a hack to ensure we could load the right weight from the LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        key = "qkv_weight"
        k1 = "q_proj.weight"
        k2 = "k_proj.weight"
        k3 = "v_proj.weight"
        q_w = state_dict[prefix + k1]
        k_w = state_dict[prefix + k2]
        v_w = state_dict[prefix + k3]

        device_mesh = self.helper_layout.device_mesh
        sharding_spec = self.helper_layout.sharding_spec
        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)

        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)

        input_param = nn.Parameter(
            qkv_w
        )  # NOTE qkv_weight doesn't have to be a dtensor, like input_param = sharded_tensor_to_param(input_param)

        param = local_state[key]

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append(
                'While copying the parameter named "{}", '
                "whose dimensions in the model are {} and "
                "whose dimensions in the checkpoint are {}, "
                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
            )

        strict = False  # to avoid unexpected_keys
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def extra_repr(self) -> str:
        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"