[inference] Adapted to Rotary Embedding and RMS Norm (#5283)

* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
pull/5297/head
yuehuayingxueluo 10 months ago committed by GitHub
parent 6e487e7d3c
commit bfff9254ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -6,7 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_kv_to_blocked_cache,
flash_decoding_attention,
rotary_embedding,
)
from colossalai.logging import get_dist_logger
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
@ -72,11 +77,12 @@ def llama_model_forward(
attention_mask = batch.get_attn_mask(padding_id)
if attention_mask is not None:
# TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
# sequence_lengths = batch.get_sequence_lengths()
if HAS_TRITON:
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
else:
sequence_lengths = batch.get_sequence_lengths()
else:
sequence_lengths = batch.get_sequence_lengths()
kv_seq_len = sequence_lengths.max().item()
@ -96,6 +102,8 @@ def llama_model_forward(
hidden_states = self.embed_tokens(input_ids)
cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
@ -107,6 +115,7 @@ def llama_model_forward(
sequence_lengths=sequence_lengths,
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
)
hidden_states = self.norm(hidden_states)
@ -125,6 +134,7 @@ def llama_decoder_layer_forward(
sequence_lengths: int = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@ -140,6 +150,7 @@ def llama_decoder_layer_forward(
sequence_lengths=sequence_lengths,
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
)
hidden_states = residual + hidden_states
@ -166,27 +177,16 @@ def llama_attn_forward(
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = max(sequence_lengths).item()
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
_, _, _, block_size = k_cache.shape
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if is_prompts:
if HAS_TRITON:
if is_prompts:
if attention_mask is not None:
query_states, key_states, value_states, indices = unpading_input(
query_states, key_states, value_states, attention_mask
@ -195,29 +195,44 @@ def llama_attn_forward(
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
else:
query_states = query_states.squeeze(dim=1)
key_states = key_states.squeeze(dim=1)
value_states = value_states.squeeze(dim=1)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
_, _, _, block_size = k_cache.shape
if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
attn_output = PagedAttention.pad_context_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
if HAS_TRITON:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
# TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
# in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
# should be revised, as we could see in previous part of `llama_attn_forward` we still have
# redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
query_states = query_states.transpose(1, 2)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
)
attn_output = attn_output.squeeze(1)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = PagedAttention.pad_context_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
attn_output = PagedAttention.pad_decoding_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
@ -232,6 +247,15 @@ def llama_attn_forward(
@torch.no_grad()
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
"""Generate padding position_id through attention mask.
Args:
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
Returns:
torch.Tensor: The padding position_id.
"""
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
return position_ids
@ -239,9 +263,34 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
"""Convert padding input to nopad input.
Args:
q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
attention_mask (torch.Tensor): [batch_size, sequence_length]
Returns:
Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch.
"""
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
return (q, k, v, indices)
@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
index_arrays = [(length - 1).view(-1) for length in lengths]
indices = torch.cat(index_arrays, dim=-1)
cos_output = cos_cache[indices].to(dtype=dtype)
sin_output = sin_cache[indices].to(dtype=dtype)
return (cos_output, sin_output)

@ -1,11 +1,13 @@
from functools import partial
import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaSdpaAttention,
)
@ -15,11 +17,31 @@ from colossalai.inference.modeling.models.llama import (
llama_decoder_layer_forward,
llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
try:
from colossalai.kernel.triton import rms_layernorm
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
return _triton_rmsnorm_forward
else:
return None
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
@ -162,4 +184,18 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)
return policy
def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model

@ -18,7 +18,6 @@ def _flash_decoding_fwd_kernel(
kv_seq_len, # [batch_size]
stride_qt,
stride_qh,
stride_ql,
stride_qd,
stride_cacheb,
stride_cacheh,
@ -199,7 +198,7 @@ def flash_decoding_attention(
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
Args:
q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
q (torch.Tensor): [bsz, num_heads, head_dim]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
kv_seq_len (torch.Tensor): [batch_size]
@ -216,7 +215,10 @@ def flash_decoding_attention(
Returns:
Output tensor with shape [bsz, num_heads, q_len, head_dim]
"""
bsz, num_heads, _, head_dim = q.shape
if q.dim() == 3:
bsz, num_heads, head_dim = q.shape
else:
raise ValueError(f"The query dim should be 3, but got {q.dim()}.")
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@ -262,7 +264,6 @@ def flash_decoding_attention(
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),

@ -53,16 +53,23 @@ def copy_kv_to_blocked_cache(
Copy keys or values to the blocked key/value cache during decoding stage.
Parameters:
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
if k.dim() == 4:
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
bsz, _, num_kv_heads, head_dim = k.shape
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
k = k.squeeze(dim=1)
elif k.dim() == 3:
bsz, num_kv_heads, head_dim = k.shape
else:
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
@ -71,8 +78,6 @@ def copy_kv_to_blocked_cache(
# Modify if the shape of kv cahce is changed.
block_size = k_cache.size(-1)
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
k = k.squeeze(dim=1)
num_warps = 8 if head_dim > 128 else 4

@ -95,10 +95,13 @@ def benchmark_inference(args):
if args.dtype == "fp16":
model = model.half()
elif args.dtype == "fp16":
elif args.dtype == "bf16":
model = model.to(torch.bfloat16)
if args.continous_batching:
mbsz = args.mbsz
else:
mbsz = args.batch_size
if args.mode == "caiinference":
inference_config = InferenceConfig(
dtype=args.dtype,
@ -205,5 +208,8 @@ if __name__ == "__main__":
choices=["caiinference", "transformers"],
help="decide which inference framework to run",
)
parser.add_argument(
"-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching"
)
args = parser.parse_args()
benchmark(args)

Loading…
Cancel
Save