[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] GQA calling of flash decoding triton

* fix kv cache alloc shape

* fix rotary triton - GQA

* fix sequence max length assigning

* Sequence max length logic

* fix scheduling and spec-dec

* skip without import error

* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 5d4c1fe8f5 (parent ccf72797e3)
Author: Yuanheng Zhao
Date:   2024-04-23 13:09:55 +08:00 (committed by GitHub)
9 changed files with 183 additions and 194 deletions


@@ -386,6 +386,7 @@ class BatchBucket:
             seq_id, seq = next(seqs_iter)
             assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
             seq.output_token_id = seq.output_token_id[:-n_tokens]
+            seq.revoke_finished_status()
             self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
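The added revoke_finished_status() call matters for speculative decoding: when rejected draft tokens are rolled back, a sequence that had just hit its length limit (and was marked finished) must be reopened so decoding can continue. A minimal sketch of the idea, using a simplified stand-in for the real Sequence/RequestStatus classes (names and fields here are illustrative, not the actual ColossalAI API):

# Hypothetical, simplified illustration of revoking tokens plus finished status.
from enum import Enum, auto

class Status(Enum):
    RUNNING = auto()
    COMPLETED = auto()

class MiniSequence:
    def __init__(self, max_output_len):
        self.output_token_id = []
        self.max_output_len = max_output_len
        self.status = Status.RUNNING

    def append(self, token_id):
        self.output_token_id.append(token_id)
        if len(self.output_token_id) >= self.max_output_len:
            self.status = Status.COMPLETED

    def revoke_tokens(self, n_tokens):
        # Roll back rejected draft tokens and reopen the sequence if needed.
        self.output_token_id = self.output_token_id[:-n_tokens]
        if self.status is Status.COMPLETED:
            self.status = Status.RUNNING

seq = MiniSequence(max_output_len=4)
for t in [11, 12, 13, 14]:  # accept 4 draft tokens -> marked finished
    seq.append(t)
seq.revoke_tokens(2)         # last 2 drafts rejected -> running again
assert seq.status is Status.RUNNING and len(seq.output_token_id) == 2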


@@ -518,7 +518,13 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )

             output_seqs_list = []
             total_tokens_list = []
@@ -573,6 +579,7 @@ class InferenceEngine:
         request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
     ) -> None:
         """
         Add requests.
@@ -629,6 +636,13 @@ class InferenceEngine:
             else:
                 prompt = prompts[i]

+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
             sequence = Sequence(
                 request_id,
                 prompt,
@@ -637,7 +651,7 @@ class InferenceEngine:
                 None,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
-                self.inference_config.max_output_len,
+                max_output_len=max_new_tokens,
             )
             self.request_handler.add_sequence(sequence)
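The new kwargs plumbing lets per-request generation settings reach add_request, where the output budget is resolved per sequence: an explicit max_length wins and is converted to new tokens by subtracting the prompt length, otherwise max_new_tokens is taken from the generation config or the engine-wide default. A standalone sketch of that resolution rule (the default values here are made up for illustration):

# Hedged sketch of the per-request output-length resolution implemented above.
from typing import Optional

def resolve_max_new_tokens(
    prompt_len: int,
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
    config_max_new_tokens: Optional[int] = None,
    config_max_output_len: int = 256,
) -> int:
    if max_length is not None:
        # max_length counts prompt + generated tokens.
        return max_length - prompt_len
    if max_new_tokens is not None:
        return max_new_tokens
    # Fall back to the generation config, then to the engine-wide default.
    return config_max_new_tokens or config_max_output_len

assert resolve_max_new_tokens(prompt_len=10, max_length=64) == 54
assert resolve_max_new_tokens(prompt_len=10, max_new_tokens=32) == 32
assert resolve_max_new_tokens(prompt_len=10) == 256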


@@ -314,10 +314,11 @@ class RequestHandler:
     def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
         for seq in batch.seqs_li:
-            if (
-                seq.output_token_id[-1] == generation_config.eos_token_id
-                or seq.output_len >= generation_config.max_length
-            ):
+            max_length = generation_config.max_length
+            max_new_tokens = generation_config.max_new_tokens
+            if max_length is not None:
+                max_new_tokens = max_length - seq.input_len
+            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()

     def check_unfinished_seqs(self) -> bool:


@@ -38,7 +38,7 @@ class KVCacheManager:
     The block table after block allocation might be:
     | 0 | 1 | 2 | -1 | -1 | -1 |
     Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
-    and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
+    and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
     corresponding to these blocks will be used to read/write KV Caches in kernels.

     For a batch of sequences, the block tables after allocation might be:
@@ -64,9 +64,12 @@ class KVCacheManager:
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
-        self.head_num //= self.tp_size
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
         self.beam_width = config.beam_width
         self.max_batch_size = config.max_batch_size
         self.max_input_length = config.max_input_len
@@ -80,9 +83,8 @@ class KVCacheManager:
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
-        # if verbose:
-        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
@@ -90,9 +92,12 @@ class KVCacheManager:
             * 2
             * self.num_blocks
             * self.block_size
-            * self.head_num
+            * self.kv_head_num
             * self.head_size
         )
+        self.logger.info(
+            f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
+        )
         # Logical cache blocks allocation
         self._available_blocks = self.num_blocks
         self._cache_blocks = tuple(self._init_logical_caches())
@@ -453,7 +458,7 @@ class KVCacheManager:
         """
         assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
         blocks = []
-        physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
+        physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
         k_ptrs = [
             self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
         ]
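Allocating the paged KV cache with kv_head_num instead of head_num is what lets GQA models such as Llama-3 pay only for their shared KV heads. A back-of-the-envelope sketch of the sizing arithmetic the manager performs, using Llama-3-8B-style numbers (32 query heads, 8 KV heads, head size 128, 32 layers) and made-up block settings:

# Hedged sketch of KV cache sizing under GQA; numbers are illustrative.
def kv_cache_bytes(num_blocks, kv_head_num, block_size, head_size, num_layers, elem_size=2):
    # One K tensor and one V tensor per layer, each of shape
    # (num_blocks, kv_head_num, block_size, head_size).
    return elem_size * num_layers * 2 * num_blocks * block_size * kv_head_num * head_size

num_blocks, block_size = 2048, 16
gqa_bytes = kv_cache_bytes(num_blocks, kv_head_num=8, block_size=block_size, head_size=128, num_layers=32)
mha_bytes = kv_cache_bytes(num_blocks, kv_head_num=32, block_size=block_size, head_size=128, num_layers=32)
print(f"GQA: {gqa_bytes / 2**30:.2f} GiB vs. full-head allocation: {mha_bytes / 2**30:.2f} GiB")
# With 8 KV heads instead of 32 query heads, the allocation is 4x smaller.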


@@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())

     @staticmethod
     def from_native_module(
@@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 mid_output=fd_inter_tensor.mid_output,
                 mid_output_lse=fd_inter_tensor.mid_output_lse,
                 sm_scale=sm_scale,
+                kv_group_num=self.num_key_value_groups,
                 q_len=q_len,
             )
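Passing kv_group_num tells the flash-decoding kernel how many query heads share each KV head (num_key_value_groups = num_attention_heads // num_key_value_heads). A small PyTorch reference for what that grouping means, independent of the Triton kernel (shapes and names here are illustrative):

# Hedged sketch: GQA attention reference by expanding KV heads to match query heads.
import torch

def gqa_attention_reference(q, k, v):
    # q: [num_q_heads, q_len, head_dim]; k, v: [num_kv_heads, kv_len, head_dim]
    num_q_heads, _, head_dim = q.shape
    num_kv_heads = k.shape[0]
    kv_group_num = num_q_heads // num_kv_heads  # query heads per shared KV head
    k = k.repeat_interleave(kv_group_num, dim=0)  # query head h reads KV head h // kv_group_num
    v = v.repeat_interleave(kv_group_num, dim=0)
    scores = q @ k.transpose(-1, -2) / head_dim**0.5
    return torch.softmax(scores, dim=-1) @ v

q = torch.randn(32, 1, 128)  # e.g. 32 query heads decoding one token
k = torch.randn(8, 64, 128)  # 8 shared KV heads over 64 cached positions
v = torch.randn(8, 64, 128)
out = gqa_attention_reference(q, k, v)  # [32, 1, 128]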


@@ -117,6 +117,14 @@ class Sequence:
             return False

+    def revoke_finished_status(self) -> None:
+        """
+        Revoke the finished status of the sequence.
+        This is only used by speculative decoding for now.
+        """
+        if RequestStatus.is_finished(self.status):
+            self.status = RequestStatus.RUNNING
+
     def __hash__(self):
         return hash(self.request_id)


@@ -36,97 +36,91 @@ def rotary_embedding_kernel(
     cos_stride,
     q_total_tokens,
     Q_HEAD_NUM: tl.constexpr,
-    K_HEAD_NUM: tl.constexpr,
+    KV_GROUP_NUM: tl.constexpr,
     HEAD_DIM: tl.constexpr,
-    BLOCK_HEAD: tl.constexpr,
-    BLOCK_TOKENS: tl.constexpr,
+    BLOCK_TOKENS: tl.constexpr,  # token range length
 ):
-    block_head_index = tl.program_id(0)
-    block_token_index = tl.program_id(1)
-    tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
-    head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+    cur_head_idx = tl.program_id(0)
+    cur_token_block_idx = tl.program_id(1)
+
+    tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
     dim_range0 = tl.arange(0, HEAD_DIM // 2)
     dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
+    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+
     off_q0 = (
         tokens_range[:, None, None] * q_token_stride
-        + head_range[None, :, None] * q_head_stride
+        + cur_head_idx * q_head_stride
         + dim_range0[None, None, :] * head_dim_stride
     )
     off_q1 = (
         tokens_range[:, None, None] * q_token_stride
-        + head_range[None, :, None] * q_head_stride
+        + cur_head_idx * q_head_stride
         + dim_range1[None, None, :] * head_dim_stride
     )
-    off_k0 = (
-        tokens_range[:, None, None] * k_token_stride
-        + head_range[None, :, None] * k_head_stride
-        + dim_range0[None, None, :] * head_dim_stride
-    )
-    off_k1 = (
-        tokens_range[:, None, None] * k_token_stride
-        + head_range[None, :, None] * k_head_stride
-        + dim_range1[None, None, :] * head_dim_stride
-    )

     loaded_q0 = tl.load(
         q + off_q0,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
         other=0.0,
     )
     loaded_q1 = tl.load(
         q + off_q1,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
         other=0.0,
     )
-    loaded_k0 = tl.load(
-        k + off_k0,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-    loaded_k1 = tl.load(
-        k + off_k1,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-
-    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
-    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
-    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)

     out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
     out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
-    out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
-    out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]

-    # concat
     tl.store(
         q + off_q0,
         out_q0,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
     )
     tl.store(
         q + off_q1,
         out_q1,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
     )
-    tl.store(
-        k + off_k0,
-        out_k0,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
-    tl.store(
-        k + off_k1,
-        out_k1,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
+
+    handle_k = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_k:
+        k_head_idx = cur_head_idx // KV_GROUP_NUM
+        off_k0 = (
+            tokens_range[:, None, None] * k_token_stride
+            + k_head_idx * k_head_stride
+            + dim_range0[None, None, :] * head_dim_stride
+        )
+        off_k1 = (
+            tokens_range[:, None, None] * k_token_stride
+            + k_head_idx * k_head_stride
+            + dim_range1[None, None, :] * head_dim_stride
+        )
+        loaded_k0 = tl.load(
+            k + off_k0,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+            other=0.0,
+        )
+        loaded_k1 = tl.load(
+            k + off_k1,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+            other=0.0,
+        )
+        out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
+        out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
+        tl.store(
+            k + off_k0,
+            out_k0,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+        )
+        tl.store(
+            k + off_k1,
+            out_k1,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+        )


 @triton.jit
 def fused_rotary_embedding_kernel(
@@ -405,108 +399,74 @@ def decoding_fused_rotary_embedding_kernel(
     bts_stride,
     btb_stride,
     block_size,
-    Q_HEAD_NUM: tl.constexpr,
+    KV_GROUP_NUM: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    block_head_index = tl.program_id(0)
-    if block_head_index >= Q_HEAD_NUM:
-        return
-    block_token_index = tl.program_id(1)
+    cur_head_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(1)

+    dim_range = tl.arange(0, HEAD_DIM)
     dim_range0 = tl.arange(0, HEAD_DIM // 2)
     dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
-    total_dim_range = tl.arange(0, HEAD_DIM)

-    q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
-    off_q0 = q_off_base + dim_range0 * head_dim_stride
-    off_q1 = q_off_base + dim_range1 * head_dim_stride
-
-    off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
-    off_k0 = off_base + dim_range0 * head_dim_stride
-    off_k1 = off_base + dim_range1 * head_dim_stride
-    off_v = off_base + total_dim_range * head_dim_stride
-
-    loaded_q0 = tl.load(
-        q + off_q0,
-    )
-    loaded_q1 = tl.load(
-        q + off_q1,
-    )
-    loaded_k0 = tl.load(
-        k + off_k0,
-    )
-    loaded_k1 = tl.load(
-        k + off_k1,
-    )
-    loaded_v = tl.load(
-        v + off_v,
-    )
-    off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
+    off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride
+    off_q0 = off_q + dim_range0 * head_dim_stride
+    off_q1 = off_q + dim_range1 * head_dim_stride
+
+    loaded_q0 = tl.load(q + off_q0)
+    loaded_q1 = tl.load(q + off_q1)
+
+    off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride
     loaded_cos = tl.load(cos + off_cos_sin)
     loaded_sin = tl.load(sin + off_cos_sin)

     out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
     out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
+    tl.store(q + off_q0, out_q0)
+    tl.store(q + off_q1, out_q1)

-    out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
-    out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos  # total_tokens, head_num, head_dim
-
-    past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
-
-    last_block_idx = past_kv_seq_len // block_size
-    block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
-    offsets_in_last_block = past_kv_seq_len % block_size
-    k_range0 = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + dim_range0 * cache_d_stride
-    )
-    k_range1 = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + dim_range1 * cache_d_stride
-    )
-    v_range = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + total_dim_range * cache_d_stride
-    )
-
-    tl.store(
-        v_cache + v_range,
-        loaded_v,
-    )
-
-    tl.store(
-        k_cache + k_range0,
-        out_k0,
-    )
-    tl.store(
-        k_cache + k_range1,
-        out_k1,
-    )
-    # concat
-    tl.store(
-        q + off_q0,
-        out_q0,
-    )
-    tl.store(
-        q + off_q1,
-        out_q1,
-    )
+    handle_k = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_k:
+        cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
+        off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
+        off_k0 = off_kv + dim_range0 * head_dim_stride
+        off_k1 = off_kv + dim_range1 * head_dim_stride
+        loaded_k0 = tl.load(k + off_k0)
+        loaded_k1 = tl.load(k + off_k1)
+
+        out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
+        out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos
+
+        # NOTE The precondition here is that it's only for unpadded inputs during decoding stage,
+        # and so that we could directly use the token index as the sequence index
+        past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1
+        last_block_idx = past_kv_seq_len // block_size
+        block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
+        offsets_in_last_block = past_kv_seq_len % block_size
+        k_range0 = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range0 * cache_d_stride
+        )
+        k_range1 = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range1 * cache_d_stride
+        )
+        tl.store(k_cache + k_range0, out_k0)
+        tl.store(k_cache + k_range1, out_k1)
+
+        off_v = off_kv + dim_range * head_dim_stride
+        loaded_v = tl.load(v + off_v)
+        v_range = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range * cache_d_stride
+        )
+        tl.store(v_cache + v_range, loaded_v)


 def rotary_embedding(
@@ -521,7 +481,7 @@ def rotary_embedding(
     """
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
-        k: key tensor, [total_tokens, head_num, head_dim]
+        k: key tensor, [total_tokens, kv_head_num, head_dim]
         cos: cosine for rotary embedding, [max_position_len, head_dim]
         sin: sine for rotary embedding, [max_position_len, head_dim]
         k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
@@ -530,32 +490,26 @@ def rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0)
-    BLOCK_HEAD = 4
     BLOCK_TOKENS = 4
-    if head_dim >= 1024:
-        num_warps = 32
-    elif head_dim >= 512:
+
+    if head_dim >= 512:
         num_warps = 16
     elif head_dim >= 256:
         num_warps = 8
     else:
         num_warps = 4

-    q_token_stride = q.stride(0)
-    q_head_stride = q.stride(1)
-    head_dim_stride = q.stride(2)
-
-    k_token_stride = k.stride(0)
-    k_head_stride = k.stride(1)
-
-    k_head_num = q.shape[1]
-
-    cos_token_stride = cos.stride(0)
-    cos_stride = cos.stride(1)
+    k_head_num = k.size(1)
+    q_token_stride, q_head_stride, head_dim_stride = q.stride()
+    k_token_stride, k_head_stride, _ = k.stride()
+    cos_token_stride, cos_stride = cos.stride()
+
+    assert q_head_num % k_head_num == 0
+    kv_group_num = q_head_num // k_head_num

     if k_cache == None:
         grid = lambda META: (
-            triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
+            q_head_num,
             triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
         )
         rotary_embedding_kernel[grid](
@@ -572,9 +526,8 @@ def rotary_embedding(
             cos_stride,
             q_total_tokens,
             Q_HEAD_NUM=q_head_num,
-            K_HEAD_NUM=k_head_num,
+            KV_GROUP_NUM=kv_group_num,
             HEAD_DIM=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
             BLOCK_TOKENS=BLOCK_TOKENS,
             num_warps=num_warps,
         )
@@ -624,23 +577,21 @@ def decoding_fused_rotary_embedding(
     """
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
-        k: key tensor, [total_tokens, head_num, head_dim]
-        v: value tensor, [total tokens, head_num, head_dim]
+        k: key tensor, [total_tokens, kv_head_num, head_dim]
+        v: value tensor, [total tokens, kv_head_num, head_dim]
         cos: cosine for rotary embedding, [max_position_len, head_dim]
         sin: sine for rotary embedding, [max_position_len, head_dim]
-        k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
-        v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
+        k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim]
+        v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]
         kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
         block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert q.size(1) == k.size(1) == v.size(1)
+    assert k.size(1) == v.size(1)
     assert k_cache.size(-1) == v_cache.size(-1)
-    if head_dim >= 1024:
-        num_warps = 32
-    elif head_dim >= 512:
+    if head_dim >= 512:
         num_warps = 16
     elif head_dim >= 256:
         num_warps = 8
@@ -653,10 +604,12 @@ def decoding_fused_rotary_embedding(
     k_token_stride = k.stride(0)
     k_head_stride = k.stride(1)
+    k_head_num = k.size(1)
+    kv_group_num = q_head_num // k_head_num

     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)

-    grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
+    grid = (q_head_num, q_total_tokens)
     decoding_fused_rotary_embedding_kernel[grid](
         q,
         k,
@@ -681,7 +634,7 @@ def decoding_fused_rotary_embedding(
         block_tables.stride(0),
         block_tables.stride(1),
         k_cache.size(-2),
-        Q_HEAD_NUM=q_head_num,
+        KV_GROUP_NUM=kv_group_num,
         HEAD_DIM=head_dim,
         num_warps=num_warps,
     )
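With the GQA-aware kernels above, the launch grid is one program per (query head, token block); a program additionally rotates a K head only when its head index is a multiple of KV_GROUP_NUM, so each of the kv_head_num K heads is written exactly once. A PyTorch reference of the rotation the kernels are expected to match (half-split convention, cos/sin of shape [total_tokens, head_dim // 2]; this is a sketch for checking, not the library API):

# Hedged PyTorch reference for the half-split rotary embedding applied to q and GQA-shaped k.
import torch

def rotary_reference(x, cos, sin):
    # x: [total_tokens, num_heads, head_dim]; cos, sin: [total_tokens, head_dim // 2]
    x0, x1 = x.chunk(2, dim=-1)
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over heads
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)

total_tokens, q_heads, kv_heads, head_dim = 16, 32, 8, 128
q = torch.randn(total_tokens, q_heads, head_dim)
k = torch.randn(total_tokens, kv_heads, head_dim)
cos = torch.randn(total_tokens, head_dim // 2)
sin = torch.randn(total_tokens, head_dim // 2)
q_ref, k_ref = rotary_reference(q, cos, sin), rotary_reference(k, cos, sin)
# rotary_embedding(q, k, cos, sin) applied in place should agree with q_ref / k_ref.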


@@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length):
     assert not engine.use_spec_dec
     assert engine.drafter is None and engine.drafter_model is None

+    max_new_tokens = max_length - dummy_inputs.size(1)
     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens

     # test GLIDE model
     glide_config = GlideLlamaConfig(
@@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length):
     engine.clear_spec_dec()

     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens


 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample):
 @parameterize("num_layers", [1])
-@parameterize("max_length", [100])
+@parameterize("max_length", [64])
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)


@@ -151,6 +151,16 @@ def test_flash_decoding_attention(
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


+try:
+    from vllm._C import ops as vllm_ops  # noqa
+
+    HAS_VLLM = True
+except ImportError:
+    HAS_VLLM = False
+    print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
+
+
+@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
 @pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
 @pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
 @pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@@ -166,11 +176,6 @@ def test_vllm_flash_decoding_attention(
     torch.cuda.synchronize()
     torch.cuda.reset_peak_memory_stats()

-    try:
-        from vllm._C import ops as vllm_ops
-    except ImportError:
-        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
-
     NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
     assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
     MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
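This test change moves the optional import to module level and converts a hard ImportError into a pytest skip, so the suite no longer fails on machines without vllm ("skip without ImportError"). The same pattern, reduced to a self-contained sketch with a made-up optional dependency name:

# Hedged sketch of the skip-if-dependency-missing pattern used above.
import pytest

try:
    import some_optional_lib  # hypothetical optional dependency

    HAS_OPTIONAL_LIB = True
except ImportError:
    HAS_OPTIONAL_LIB = False


@pytest.mark.skipif(not HAS_OPTIONAL_LIB, reason="requires some_optional_lib")
def test_feature_backed_by_optional_lib():
    assert some_optional_lib is not None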