mirror of https://github.com/hpcaitech/ColossalAI
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding
* adapted to nopad rms norm
* fix bugs in benchmark
* fix flash_decoding.py
parent 6e487e7d3c
commit bfff9254ac
@@ -6,7 +6,12 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
+from colossalai.kernel.triton import (
+    context_attention_unpadded,
+    copy_kv_to_blocked_cache,
+    flash_decoding_attention,
+    rotary_embedding,
+)
 from colossalai.logging import get_dist_logger

 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
@@ -72,9 +77,10 @@ def llama_model_forward(
     attention_mask = batch.get_attn_mask(padding_id)

     if attention_mask is not None:
-        # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
-        # sequence_lengths = batch.get_sequence_lengths()
-        sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
+        if HAS_TRITON:
+            sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
+        else:
+            sequence_lengths = batch.get_sequence_lengths()
     else:
         sequence_lengths = batch.get_sequence_lengths()

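In the Triton path above, per-sequence lengths are taken directly from the padding mask rather than from the batch metadata. A minimal sketch of that reduction, using a made-up two-sequence mask:

```python
import torch

# Hypothetical 2-sequence batch, right-padded with 0s in the attention mask.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Same reduction as the Triton branch above: one length per sequence.
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
print(sequence_lengths)  # tensor([3, 5], dtype=torch.int32)
```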
@@ -96,6 +102,8 @@ def llama_model_forward(

     hidden_states = self.embed_tokens(input_ids)

+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,6 +115,7 @@ def llama_model_forward(
             sequence_lengths=sequence_lengths,
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
+            cos_sin=cos_sin,
         )

     hidden_states = self.norm(hidden_states)
@@ -125,6 +134,7 @@ def llama_decoder_layer_forward(
     sequence_lengths: int = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states

@@ -140,6 +150,7 @@ def llama_decoder_layer_forward(
         sequence_lengths=sequence_lengths,
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
+        cos_sin=cos_sin,
     )

     hidden_states = residual + hidden_states
@@ -166,27 +177,16 @@ def llama_attn_forward(
     sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
+    cos_sin: Tuple[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)

-    kv_seq_len = max(sequence_lengths).item()
-
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.transpose(1, 2)
-
-    _, _, _, block_size = k_cache.shape
-
-    if is_prompts:
-        if HAS_TRITON:
+    if HAS_TRITON:
+        if is_prompts:
             if attention_mask is not None:
                 query_states, key_states, value_states, indices = unpading_input(
                     query_states, key_states, value_states, attention_mask
@@ -195,29 +195,44 @@ def llama_attn_forward(
                 query_states = query_states.view(-1, self.num_heads, self.head_dim)
                 key_states = key_states.view(-1, self.num_heads, self.head_dim)
                 value_states = value_states.view(-1, self.num_heads, self.head_dim)
+        else:
+            query_states = query_states.squeeze(dim=1)
+            key_states = key_states.squeeze(dim=1)
+            value_states = value_states.squeeze(dim=1)
+
+        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+
+        _, _, _, block_size = k_cache.shape
+
+        if is_prompts:
             attn_output = context_attention_unpadded(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
             )
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
         else:
-            attn_output = PagedAttention.pad_context_forward(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-            )
-    else:
-        if HAS_TRITON:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
-            # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
-            # should be revised, as we could see in previous part of `llama_attn_forward` we still have
-            # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
-            query_states = query_states.transpose(1, 2)
             attn_output = flash_decoding_attention(
                 query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
             )
             attn_output = attn_output.squeeze(1)
+    else:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        if is_prompts:
+            attn_output = PagedAttention.pad_context_forward(
+                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+            )
         else:
             attn_output = PagedAttention.pad_decoding_forward(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
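The `rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])` call above hands token-major `[num_tokens, num_heads, head_dim]` tensors to the fused Triton kernel; since its result is not assigned, it presumably rotates q/k in place. A plain-PyTorch sketch of the rotation it is expected to apply is below. The half-split layout and the `[num_tokens, head_dim // 2]` cos/sin shape are assumptions for illustration, not facts taken from this diff.

```python
import torch

def rotary_embedding_ref(q, k, cos, sin):
    """Reference rotation (shapes assumed): q/k [num_tokens, num_heads, head_dim],
    cos/sin [num_tokens, head_dim // 2], one row per token position."""
    def rotate(x):
        x1, x2 = x.chunk(2, dim=-1)                  # split the head dim into two halves
        c, s = cos.unsqueeze(1), sin.unsqueeze(1)    # broadcast over the head axis
        return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
    return rotate(q), rotate(k)
```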
@@ -232,6 +247,15 @@ def llama_attn_forward(

 @torch.no_grad()
 def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
+    """Generate padding position_id through attention mask.
+
+    Args:
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+    Returns:
+        torch.Tensor: The padding position_id.
+    """
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
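A quick look at what `generate_padding_position_id` yields for a left-padded toy batch: padded slots become 1 and real tokens are numbered from 0.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```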
@@ -239,9 +263,34 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:

 @torch.no_grad()
 def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
+    """Convert padding input to nopad input.
+
+    Args:
+        q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
+        attention_mask (torch.Tensor): [batch_size, sequence_length]
+
+    Returns:
+        Tuple[torch.Tensor]: The unpadded q, k, v and the indices of valid data in each batch.
+
+    """
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
     q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
     return (q, k, v, indices)
+
+
+@torch.no_grad()
+def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    if is_prompts:
+        index_arrays = [torch.arange(length) for length in lengths]
+    else:
+        index_arrays = [(length - 1).view(-1) for length in lengths]
+    indices = torch.cat(index_arrays, dim=-1)
+    cos_output = cos_cache[indices].to(dtype=dtype)
+    sin_output = sin_cache[indices].to(dtype=dtype)
+
+    return (cos_output, sin_output)
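To make the two helpers above concrete, here is a small shape walk-through with made-up sizes. The boolean gather mirrors what `index_first_axis` does, so `flash_attn` itself is not needed for the illustration:

```python
import torch

# Made-up sizes for illustration.
bsz, seq_len, num_heads, head_dim = 2, 4, 2, 8
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])
q = torch.randn(bsz, seq_len, num_heads, head_dim)

# unpading_input keeps only the rows of real tokens (same gather as index_first_axis).
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
q_unpad = q.reshape(bsz * seq_len, num_heads, head_dim)[indices]
print(q_unpad.shape)  # torch.Size([6, 2, 8])

# get_cos_sin picks one cached cos/sin row per selected position:
lengths = attention_mask.sum(dim=-1)
prefill_idx = torch.cat([torch.arange(l) for l in lengths])   # tensor([0, 1, 0, 1, 2, 3])
decode_idx = torch.cat([(l - 1).view(-1) for l in lengths])   # tensor([1, 3])
```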
@@ -1,11 +1,13 @@
 from functools import partial

+import torch
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
     LlamaFlashAttention2,
     LlamaForCausalLM,
     LlamaModel,
+    LlamaRMSNorm,
     LlamaSdpaAttention,
 )

@@ -15,11 +17,31 @@ from colossalai.inference.modeling.models.llama import (
     llama_decoder_layer_forward,
     llama_model_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

+try:
+    from colossalai.kernel.triton import rms_layernorm
+
+    HAS_TRITON_RMSNORM = True
+except:
+    print("you should install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+    if HAS_TRITON_RMSNORM:
+
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
+
+        return _triton_rmsnorm_forward
+    else:
+        return None
+
+
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
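`get_triton_rmsnorm_forward` above swaps `LlamaRMSNorm.forward` for the Triton `rms_layernorm` kernel. For reference, a plain-PyTorch version of what that kernel is expected to compute (mirroring Hugging Face's `LlamaRMSNorm`) is sketched below; it is a comparison aid, not the kernel's actual implementation:

```python
import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the RMS over the last dim in fp32, then scale and cast back.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
```

Comparing the fused kernel against this reference on random inputs (within fp16/bf16 tolerance) is a cheap way to sanity-check the replacement.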
@@ -162,4 +184,18 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
             )

+        infer_forward = None
+        if HAS_TRITON_RMSNORM:
+            infer_forward = get_triton_rmsnorm_forward()
+
+        if infer_forward is not None:
+            method_replacement = {"forward": partial(infer_forward)}
+            self.append_or_create_method_replacement(
+                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+            )
+
         return policy
+
+    def postprocess(self):
+        init_to_get_rotary(self.model.model)
+        return self.model
@@ -18,7 +18,6 @@ def _flash_decoding_fwd_kernel(
     kv_seq_len,  # [batch_size]
     stride_qt,
     stride_qh,
-    stride_ql,
     stride_qd,
     stride_cacheb,
     stride_cacheh,
@@ -199,7 +198,7 @@ def flash_decoding_attention(
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

     Args:
-        q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
+        q (torch.Tensor): [bsz, num_heads, head_dim]
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         kv_seq_len (torch.Tensor): [batch_size]
@@ -216,7 +215,10 @@ def flash_decoding_attention(
     Returns:
         Output tensor with shape [bsz, num_heads, q_len, head_dim]
     """
-    bsz, num_heads, _, head_dim = q.shape
+    if q.dim() == 3:
+        bsz, num_heads, head_dim = q.shape
+    else:
+        raise ValueError(f"The query dim should be 3, but got {q.dim()}.")

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
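With this change the decoding query must be a 3-D tensor. In the model code above that layout is reached with `squeeze(dim=1)` on the per-step projection; a toy illustration with made-up sizes:

```python
import torch

bsz, num_heads, head_dim = 2, 8, 64

# During decoding each sequence contributes one token, so q starts as [bsz, 1, num_heads, head_dim].
q = torch.randn(bsz, 1, num_heads, head_dim)
q = q.squeeze(dim=1)  # -> [bsz, num_heads, head_dim], the only layout the kernel accepts now
assert q.dim() == 3
```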
@@ -262,7 +264,6 @@ def flash_decoding_attention(
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
@@ -53,16 +53,23 @@ def copy_kv_to_blocked_cache(
     Copy keys or values to the blocked key/value cache during decoding stage.

     Parameters:
-    - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
+    - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
     - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
     - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
     - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
     """
-    assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
-    assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
     assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
     assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
-    bsz, _, num_kv_heads, head_dim = k.shape
+    if k.dim() == 4:
+        assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
+        bsz, _, num_kv_heads, head_dim = k.shape
+        # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
+        k = k.squeeze(dim=1)
+    elif k.dim() == 3:
+        bsz, num_kv_heads, head_dim = k.shape
+    else:
+        raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
+
     assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
         f"Got incompatible batch size (number of seqs):\n"
         f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
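For intuition about what the copy kernel does with these shapes, here is a loop-based reference of its assumed effect: the current token's key (or value) is written into the block and in-block offset that `block_tables` and `kv_lengths` point at. This is an illustration of the paged-cache layout described in the docstring, not the kernel's actual code:

```python
import torch

def copy_kv_to_blocked_cache_ref(k, k_cache, kv_lengths, block_tables):
    # k:            [bsz, num_kv_heads, head_dim]             (current decoding step only)
    # k_cache:      [num_blocks, num_kv_heads, head_dim, block_size]
    # kv_lengths:   [bsz] past length + current token, per sequence
    # block_tables: [bsz, max_blocks_per_sequence]
    block_size = k_cache.size(-1)
    for i in range(k.size(0)):
        pos = kv_lengths[i].item() - 1                  # slot index of the current token
        block_id = block_tables[i, pos // block_size]   # physical block holding that slot
        k_cache[block_id, :, :, pos % block_size] = k[i]
    return k_cache
```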
@@ -71,8 +78,6 @@ def copy_kv_to_blocked_cache(

     # Modify if the shape of kv cache is changed.
     block_size = k_cache.size(-1)
-    # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
-    k = k.squeeze(dim=1)

     num_warps = 8 if head_dim > 128 else 4

@@ -95,10 +95,13 @@ def benchmark_inference(args):

     if args.dtype == "fp16":
         model = model.half()
-    elif args.dtype == "fp16":
+    elif args.dtype == "bf16":
         model = model.to(torch.bfloat16)

-    mbsz = args.mbsz
+    if args.continous_batching:
+        mbsz = args.mbsz
+    else:
+        mbsz = args.batch_size
     if args.mode == "caiinference":
         inference_config = InferenceConfig(
             dtype=args.dtype,
@@ -205,5 +208,8 @@ if __name__ == "__main__":
         choices=["caiinference", "transformers"],
         help="decide which inference framework to run",
     )
+    parser.add_argument(
+        "-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching"
+    )
     args = parser.parse_args()
     benchmark(args)