From bfff9254ac8ca866673746ec47cfd2f87aab2b66 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 22 Jan 2024 10:55:34 +0800 Subject: [PATCH] [inference] Adapted to Rotary Embedding and RMS Norm (#5283) * adapted to rotary_embedding * adapted to nopad rms norm * fix bugs in benchmark * fix flash_decoding.py --- colossalai/inference/modeling/models/llama.py | 111 +++++++++++++----- colossalai/inference/modeling/policy/llama.py | 36 ++++++ colossalai/kernel/triton/flash_decoding.py | 9 +- colossalai/kernel/triton/kvcache_copy.py | 17 ++- examples/inference/benchmark_llama.py | 10 +- 5 files changed, 140 insertions(+), 43 deletions(-) diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index 09e95070a..ffd7d2292 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -6,7 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.inference.struct import BatchInfo -from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention +from colossalai.kernel.triton import ( + context_attention_unpadded, + copy_kv_to_blocked_cache, + flash_decoding_attention, + rotary_embedding, +) from colossalai.logging import get_dist_logger from flash_attn.bert_padding import index_first_axis, pad_input # noqa @@ -72,9 +77,10 @@ def llama_model_forward( attention_mask = batch.get_attn_mask(padding_id) if attention_mask is not None: - # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths. - # sequence_lengths = batch.get_sequence_lengths() - sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32) + if HAS_TRITON: + sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32) + else: + sequence_lengths = batch.get_sequence_lengths() else: sequence_lengths = batch.get_sequence_lengths() @@ -96,6 +102,8 @@ def llama_model_forward( hidden_states = self.embed_tokens(input_ids) + cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype) + for layer_id, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, @@ -107,6 +115,7 @@ def llama_model_forward( sequence_lengths=sequence_lengths, attention_mask=attention_mask, kv_seq_len=kv_seq_len, + cos_sin=cos_sin, ) hidden_states = self.norm(hidden_states) @@ -125,6 +134,7 @@ def llama_decoder_layer_forward( sequence_lengths: int = None, attention_mask: torch.Tensor = None, kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -140,6 +150,7 @@ def llama_decoder_layer_forward( sequence_lengths=sequence_lengths, attention_mask=attention_mask, kv_seq_len=kv_seq_len, + cos_sin=cos_sin, ) hidden_states = residual + hidden_states @@ -166,27 +177,16 @@ def llama_attn_forward( sequence_lengths: torch.Tensor = None, attention_mask: torch.Tensor = None, kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, 
self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = max(sequence_lengths).item() - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - _, _, _, block_size = k_cache.shape - - if is_prompts: - if HAS_TRITON: + if HAS_TRITON: + if is_prompts: if attention_mask is not None: query_states, key_states, value_states, indices = unpading_input( query_states, key_states, value_states, attention_mask @@ -195,29 +195,44 @@ def llama_attn_forward( query_states = query_states.view(-1, self.num_heads, self.head_dim) key_states = key_states.view(-1, self.num_heads, self.head_dim) value_states = value_states.view(-1, self.num_heads, self.head_dim) + else: + query_states = query_states.squeeze(dim=1) + key_states = key_states.squeeze(dim=1) + value_states = value_states.squeeze(dim=1) + + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + _, _, _, block_size = k_cache.shape + + if is_prompts: attn_output = context_attention_unpadded( query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size ) if attention_mask is not None: attn_output = pad_input(attn_output, indices, bsz, q_len) else: - attn_output = PagedAttention.pad_context_forward( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask - ) - else: - if HAS_TRITON: copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel - # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output - # should be revised, as we could see in previous part of `llama_attn_forward` we still have - # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent. 
- query_states = query_states.transpose(1, 2) attn_output = flash_decoding_attention( query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size ) attn_output = attn_output.squeeze(1) + else: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if is_prompts: + attn_output = PagedAttention.pad_context_forward( + query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask + ) else: attn_output = PagedAttention.pad_decoding_forward( query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask @@ -232,6 +247,15 @@ def llama_attn_forward( @torch.no_grad() def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor: + """Generate padding position_id through attention mask. + + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + Returns: + torch.Tensor: The padding position_id. + """ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) return position_ids @@ -239,9 +263,34 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor: @torch.no_grad() def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor): + """Convert padding input to nopad input. + + Args: + q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + attention_mask (torch.Tensor): [batch_size, sequence_length] + + Returns: + Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch. 
+ + """ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) return (q, k, v, indices) + + +@torch.no_grad() +def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype): + if is_prompts: + index_arrays = [torch.arange(length) for length in lengths] + else: + index_arrays = [(length - 1).view(-1) for length in lengths] + indices = torch.cat(index_arrays, dim=-1) + cos_output = cos_cache[indices].to(dtype=dtype) + sin_output = sin_cache[indices].to(dtype=dtype) + + return (cos_output, sin_output) diff --git a/colossalai/inference/modeling/policy/llama.py b/colossalai/inference/modeling/policy/llama.py index 6e4d074db..514c274ad 100644 --- a/colossalai/inference/modeling/policy/llama.py +++ b/colossalai/inference/modeling/policy/llama.py @@ -1,11 +1,13 @@ from functools import partial +import torch from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, LlamaFlashAttention2, LlamaForCausalLM, LlamaModel, + LlamaRMSNorm, LlamaSdpaAttention, ) @@ -15,11 +17,31 @@ from colossalai.inference.modeling.models.llama import ( llama_decoder_layer_forward, llama_model_forward, ) +from colossalai.inference.utils import init_to_get_rotary from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy +try: + from colossalai.kernel.triton import rms_layernorm + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + class LlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: @@ -162,4 +184,18 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): description=method_replacement, policy=policy, target_key=LlamaSdpaAttention ) + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 15f1921ca..fec12f604 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -18,7 +18,6 @@ def _flash_decoding_fwd_kernel( kv_seq_len, # [batch_size] stride_qt, stride_qh, - stride_ql, stride_qd, stride_cacheb, stride_cacheh, @@ -199,7 +198,7 @@ def flash_decoding_attention( Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. 
Args: - q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim] + q (torch.Tensor): [bsz, num_heads, head_dim] k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] kv_seq_len (torch.Tensor): [batch_size] @@ -216,7 +215,10 @@ def flash_decoding_attention( Returns: Output tensor with shape [bsz, num_heads, q_len, head_dim] """ - bsz, num_heads, _, head_dim = q.shape + if q.dim() == 3: + bsz, num_heads, head_dim = q.shape + else: + raise ValueError(f"The query dim should be 3, but got {q.dim()}.") assert head_dim in {32, 64, 128, 256} assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( @@ -262,7 +264,6 @@ def flash_decoding_attention( q.stride(0), q.stride(1), q.stride(2), - q.stride(3), k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index 253b3912e..74f20c33b 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -53,16 +53,23 @@ def copy_kv_to_blocked_cache( Copy keys or values to the blocked key/value cache during decoding stage. Parameters: - - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. + - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache. - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. """ - assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)" - assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)" assert k.size(-1) == k_cache.size(-2), "Incompatible head dim" assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." - bsz, _, num_kv_heads, head_dim = k.shape + if k.dim() == 4: + assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)" + bsz, _, num_kv_heads, head_dim = k.shape + # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim] + k = k.squeeze(dim=1) + elif k.dim() == 3: + bsz, num_kv_heads, head_dim = k.shape + else: + raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.") + assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " @@ -71,8 +78,6 @@ def copy_kv_to_blocked_cache( # Modify if the shape of kv cahce is changed. 
block_size = k_cache.size(-1) - # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim] - k = k.squeeze(dim=1) num_warps = 8 if head_dim > 128 else 4 diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 457546a7f..bcc426e3a 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -95,10 +95,13 @@ def benchmark_inference(args): if args.dtype == "fp16": model = model.half() - elif args.dtype == "fp16": + elif args.dtype == "bf16": model = model.to(torch.bfloat16) - mbsz = args.mbsz + if args.continous_batching: + mbsz = args.mbsz + else: + mbsz = args.batch_size if args.mode == "caiinference": inference_config = InferenceConfig( dtype=args.dtype, @@ -205,5 +208,8 @@ if __name__ == "__main__": choices=["caiinference", "transformers"], help="decide which inference framework to run", ) + parser.add_argument( + "-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching" + ) args = parser.parse_args() benchmark(args)
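
A note on the new cos/sin handling: the `get_cos_sin` helper added to colossalai/inference/modeling/models/llama.py gathers one (cos, sin) row per token from the `_cos_cached`/`_sin_cached` buffers that `init_to_get_rotary` sets up in the policy's `postprocess`. Prefill gathers positions 0..len-1 for every sequence; decode gathers only position len-1. Below is a minimal plain-PyTorch sketch of that gather; the cache sizes, contents, and the `gather_cos_sin` name are made up for illustration, while the real path feeds the result into the in-place Triton `rotary_embedding(q, k, cos, sin)` call shown in the diff.

import torch

# Illustrative stand-in for get_cos_sin(); cache shapes below are assumptions.
def gather_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
    if is_prompts:
        # Prefill: every un-padded token of every sequence needs its own position.
        index_arrays = [torch.arange(length) for length in lengths]
    else:
        # Decode: only the newest token, i.e. position (length - 1), per sequence.
        index_arrays = [(length - 1).view(-1) for length in lengths]
    indices = torch.cat(index_arrays, dim=-1)
    return cos_cache[indices].to(dtype=dtype), sin_cache[indices].to(dtype=dtype)

max_positions, rot_dim = 4096, 64                   # assumed cache shape
cos_cache = torch.rand(max_positions, rot_dim)      # stands in for self._cos_cached
sin_cache = torch.rand(max_positions, rot_dim)      # stands in for self._sin_cached
lengths = torch.tensor([3, 5])                      # two sequences in the batch

cos, sin = gather_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=torch.float16)
assert cos.shape == (3 + 5, rot_dim)                # one row per un-padded token (prefill)

cos, sin = gather_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=torch.float16)
assert cos.shape == (2, rot_dim)                    # one row per sequence (decode)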
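
On the policy side, `LlamaRMSNorm.forward` is rebound to the Triton `rms_layernorm` kernel when Triton is available. As a hedge for sanity-checking that swap, here is a small eager-mode reference of the math the kernel is expected to match (it simply restates `LlamaRMSNorm` from transformers); the shapes, eps value, and the commented comparison are illustrative only and assume a CUDA device.

import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same math as transformers' LlamaRMSNorm: scale by the reciprocal root
    # mean square over the last dimension, then apply the learned weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

# Illustrative check against the Triton kernel (assumes a CUDA device):
# x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
# w = torch.ones(4096, dtype=torch.float16, device="cuda")
# torch.testing.assert_close(rms_layernorm(x, w, 1e-6), rmsnorm_reference(x, w, 1e-6), rtol=1e-2, atol=1e-2)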
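
Finally, a sketch of the decode-stage shape contract the Triton kernels now expect: `llama_attn_forward` squeezes the length-1 sequence dimension out of q/k/v, so `flash_decoding_attention` receives a 3-D query while `copy_kv_to_blocked_cache` accepts either the squeezed 3-D layout or the original 4-D one. The sizes below are illustrative.

import torch

bsz, num_heads, num_kv_heads, head_dim = 2, 32, 32, 128   # illustrative sizes
q = torch.randn(bsz, 1, num_heads, head_dim)               # q_len == 1 during decode
k = torch.randn(bsz, 1, num_kv_heads, head_dim)

# After this patch the attention forward squeezes the length-1 dim up front ...
q3 = q.squeeze(dim=1)                                       # [bsz, num_heads, head_dim]
k3 = k.squeeze(dim=1)                                       # [bsz, num_kv_heads, head_dim]
assert q3.dim() == 3 and k3.shape == (bsz, num_kv_heads, head_dim)

# ... so flash_decoding_attention consumes the 3-D query directly, while
# copy_kv_to_blocked_cache keeps a fallback that squeezes a 4-D input itself.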