|
|
|
@ -6,7 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
|
|
|
|
|
|
|
|
|
|
from colossalai.inference.modeling.layers.attention import PagedAttention
|
|
|
|
|
from colossalai.inference.struct import BatchInfo
|
|
|
|
|
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
|
|
|
|
|
from colossalai.kernel.triton import (
|
|
|
|
|
context_attention_unpadded,
|
|
|
|
|
copy_kv_to_blocked_cache,
|
|
|
|
|
flash_decoding_attention,
|
|
|
|
|
rotary_embedding,
|
|
|
|
|
)
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
|
|
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
|
|
|
@ -72,9 +77,10 @@ def llama_model_forward(
|
|
|
|
|
attention_mask = batch.get_attn_mask(padding_id)
|
|
|
|
|
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
# TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
|
|
|
|
|
# sequence_lengths = batch.get_sequence_lengths()
|
|
|
|
|
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
|
|
if HAS_TRITON:
|
|
|
|
|
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
|
|
else:
|
|
|
|
|
sequence_lengths = batch.get_sequence_lengths()
|
|
|
|
|
else:
|
|
|
|
|
sequence_lengths = batch.get_sequence_lengths()
|
|
|
|
|
|
|
|
|
@ -96,6 +102,8 @@ def llama_model_forward(
|
|
|
|
|
|
|
|
|
|
hidden_states = self.embed_tokens(input_ids)
|
|
|
|
|
|
|
|
|
|
cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
|
|
|
|
|
|
|
|
|
|
for layer_id, decoder_layer in enumerate(self.layers):
|
|
|
|
|
hidden_states = decoder_layer(
|
|
|
|
|
hidden_states,
|
|
|
|
@ -107,6 +115,7 @@ def llama_model_forward(
|
|
|
|
|
sequence_lengths=sequence_lengths,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
kv_seq_len=kv_seq_len,
|
|
|
|
|
cos_sin=cos_sin,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = self.norm(hidden_states)
|
|
|
|
@ -125,6 +134,7 @@ def llama_decoder_layer_forward(
|
|
|
|
|
sequence_lengths: int = None,
|
|
|
|
|
attention_mask: torch.Tensor = None,
|
|
|
|
|
kv_seq_len: int = 0,
|
|
|
|
|
cos_sin: Tuple[torch.Tensor] = None,
|
|
|
|
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
|
|
|
residual = hidden_states
|
|
|
|
|
|
|
|
|
@ -140,6 +150,7 @@ def llama_decoder_layer_forward(
|
|
|
|
|
sequence_lengths=sequence_lengths,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
kv_seq_len=kv_seq_len,
|
|
|
|
|
cos_sin=cos_sin,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
@ -166,27 +177,16 @@ def llama_attn_forward(
|
|
|
|
|
sequence_lengths: torch.Tensor = None,
|
|
|
|
|
attention_mask: torch.Tensor = None,
|
|
|
|
|
kv_seq_len: int = 0,
|
|
|
|
|
cos_sin: Tuple[torch.Tensor] = None,
|
|
|
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
|
|
|
|
|
|
|
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
kv_seq_len = max(sequence_lengths).item()
|
|
|
|
|
|
|
|
|
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
|
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
|
|
|
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
|
|
|
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
|
|
|
|
|
|
|
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
|
|
|
|
|
|
query_states = query_states.transpose(1, 2)
|
|
|
|
|
key_states = key_states.transpose(1, 2)
|
|
|
|
|
value_states = value_states.transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
_, _, _, block_size = k_cache.shape
|
|
|
|
|
|
|
|
|
|
if is_prompts:
|
|
|
|
|
if HAS_TRITON:
|
|
|
|
|
if HAS_TRITON:
|
|
|
|
|
if is_prompts:
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
query_states, key_states, value_states, indices = unpading_input(
|
|
|
|
|
query_states, key_states, value_states, attention_mask
|
|
|
|
@ -195,29 +195,44 @@ def llama_attn_forward(
|
|
|
|
|
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
|
|
|
|
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
|
|
|
|
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
|
|
|
|
else:
|
|
|
|
|
query_states = query_states.squeeze(dim=1)
|
|
|
|
|
key_states = key_states.squeeze(dim=1)
|
|
|
|
|
value_states = value_states.squeeze(dim=1)
|
|
|
|
|
|
|
|
|
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
|
|
|
|
|
|
|
|
|
_, _, _, block_size = k_cache.shape
|
|
|
|
|
|
|
|
|
|
if is_prompts:
|
|
|
|
|
attn_output = context_attention_unpadded(
|
|
|
|
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
|
|
|
|
)
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
|
|
|
|
else:
|
|
|
|
|
attn_output = PagedAttention.pad_context_forward(
|
|
|
|
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if HAS_TRITON:
|
|
|
|
|
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
|
|
|
|
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
|
|
|
|
# TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
|
|
|
|
|
# in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
|
|
|
|
|
# should be revised, as we could see in previous part of `llama_attn_forward` we still have
|
|
|
|
|
# redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
|
|
|
|
|
query_states = query_states.transpose(1, 2)
|
|
|
|
|
attn_output = flash_decoding_attention(
|
|
|
|
|
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
|
|
|
|
)
|
|
|
|
|
attn_output = attn_output.squeeze(1)
|
|
|
|
|
else:
|
|
|
|
|
query_states = query_states.transpose(1, 2)
|
|
|
|
|
key_states = key_states.transpose(1, 2)
|
|
|
|
|
value_states = value_states.transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
|
|
|
|
|
|
query_states = query_states.transpose(1, 2)
|
|
|
|
|
key_states = key_states.transpose(1, 2)
|
|
|
|
|
value_states = value_states.transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
if is_prompts:
|
|
|
|
|
attn_output = PagedAttention.pad_context_forward(
|
|
|
|
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
attn_output = PagedAttention.pad_decoding_forward(
|
|
|
|
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
|
|
|
@ -232,6 +247,15 @@ def llama_attn_forward(
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
"""Generate padding position_id through attention mask.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
|
|
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
torch.Tensor: The padding position_id.
|
|
|
|
|
"""
|
|
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
|
|
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
|
|
|
return position_ids
|
|
|
|
@ -239,9 +263,34 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
|
|
|
|
|
"""Convert padding input to nopad input.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
|
|
|
k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
|
|
|
v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
|
|
|
attention_mask (torch.Tensor): [batch_size, sequence_length]
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
|
|
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
|
|
|
|
|
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
|
|
|
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
|
|
|
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
|
|
|
return (q, k, v, indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
|
|
|
|
|
if is_prompts:
|
|
|
|
|
index_arrays = [torch.arange(length) for length in lengths]
|
|
|
|
|
else:
|
|
|
|
|
index_arrays = [(length - 1).view(-1) for length in lengths]
|
|
|
|
|
indices = torch.cat(index_arrays, dim=-1)
|
|
|
|
|
cos_output = cos_cache[indices].to(dtype=dtype)
|
|
|
|
|
sin_output = sin_cache[indices].to(dtype=dtype)
|
|
|
|
|
|
|
|
|
|
return (cos_output, sin_output)
|
|
|
|
|