[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] fix GQA call of the flash-decoding Triton kernel

* fix KV cache allocation shape

* fix rotary embedding Triton kernel for GQA

* fix sequence max-length assignment

* sequence max-length logic

* fix scheduling and spec-dec

* skip tests instead of raising an import error

* fix pytest: skip on ImportError instead of failing

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Yuanheng Zhao, committed via GitHub
parent ccf72797e3
commit 5d4c1fe8f5

@@ -386,6 +386,7 @@ class BatchBucket:
             seq_id, seq = next(seqs_iter)
             assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
             seq.output_token_id = seq.output_token_id[:-n_tokens]
+            seq.revoke_finished_status()
             self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:

@@ -518,7 +518,13 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )

             output_seqs_list = []
             total_tokens_list = []
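The change above forwards GenerationConfig fields into add_request() as keyword arguments so that length settings reach sequence construction. A minimal sketch of that hand-off, assuming a standard Hugging Face GenerationConfig (values are illustrative):

    # Hedged sketch: generation settings flow into add_request() as **kwargs.
    from transformers import GenerationConfig

    generation_config = GenerationConfig(max_new_tokens=128)  # illustrative value
    gen_config_dict = generation_config.to_dict()             # contains max_new_tokens, max_length, ...
    # engine.add_request(request_ids=..., prompts=..., prompts_token_ids=..., **gen_config_dict)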
@@ -573,6 +579,7 @@ class InferenceEngine:
         request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
     ) -> None:
         """
         Add requests.
@@ -629,6 +636,13 @@ class InferenceEngine:
             else:
                 prompt = prompts[i]

+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
             sequence = Sequence(
                 request_id,
                 prompt,
@@ -637,7 +651,7 @@ class InferenceEngine:
                 None,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
-                self.inference_config.max_output_len,
+                max_output_len=max_new_tokens,
             )
             self.request_handler.add_sequence(sequence)
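The branch added above resolves the per-request output budget from either max_length (which counts the prompt) or max_new_tokens. A small self-contained sketch of the same logic; the default value is illustrative, not the engine's:

    # Hedged mirror of the branch in add_request(); default is illustrative.
    def resolve_max_new_tokens(prompt_len, max_length=None, max_new_tokens=None, default_max_output_len=256):
        if max_length is None and max_new_tokens is None:
            return default_max_output_len      # fall back to the engine config
        if max_length is not None:
            return max_length - prompt_len     # max_length includes the prompt tokens
        return max_new_tokens

    assert resolve_max_new_tokens(prompt_len=10, max_length=64) == 54
    assert resolve_max_new_tokens(prompt_len=10, max_new_tokens=32) == 32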

@@ -314,10 +314,11 @@ class RequestHandler:
     def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
         for seq in batch.seqs_li:
-            if (
-                seq.output_token_id[-1] == generation_config.eos_token_id
-                or seq.output_len >= generation_config.max_length
-            ):
+            max_length = generation_config.max_length
+            max_new_tokens = generation_config.max_new_tokens
+            if max_length is not None:
+                max_new_tokens = max_length - seq.input_len
+            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()

     def check_unfinished_seqs(self) -> bool:

@@ -38,7 +38,7 @@ class KVCacheManager:
         The block table after block allocation might be:
         | 0 | 1 | 2 | -1 | -1 | -1 |
         Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
-        and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
+        and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
         corresponding to these blocks will be used to read/write KV Caches in kernels.

         For a batch of sequences, the block tables after allocation might be:
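With GQA, the per-block physical cache is sized by the number of key/value heads rather than the number of attention heads. A worked example with illustrative Llama-3-8B-style numbers (not read from this config):

    # Worked example of the per-block cache size formula above (illustrative values).
    block_size, kv_head_num, head_size, elem_size = 16, 8, 128, 2  # fp16 element size
    bytes_per_block_per_layer = block_size * kv_head_num * head_size * elem_size
    print(bytes_per_block_per_layer)  # 32768 bytes = 32 KiB per layer, for K or V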
@@ -64,9 +64,12 @@ class KVCacheManager:
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
-        self.head_num //= self.tp_size
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
         self.beam_width = config.beam_width
         self.max_batch_size = config.max_batch_size
         self.max_input_length = config.max_input_len
@@ -80,9 +83,8 @@ class KVCacheManager:
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
-        # if verbose:
-        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
@@ -90,9 +92,12 @@ class KVCacheManager:
             * 2
             * self.num_blocks
             * self.block_size
-            * self.head_num
+            * self.kv_head_num
             * self.head_size
         )
+        self.logger.info(
+            f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
+        )
         # Logical cache blocks allocation
         self._available_blocks = self.num_blocks
         self._cache_blocks = tuple(self._init_logical_caches())
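The new log line reports the total allocated KV cache in gigabytes. A hedged sketch of that quantity, assuming GIGABYTE is 1 << 30 and that the full product also includes the layer count; all numbers are illustrative:

    # Illustrative arithmetic for the logged total (assumed factors, not repo code).
    GIGABYTE = 1 << 30
    elem_size, num_layers, num_blocks, block_size, kv_head_num, head_size = 2, 32, 2048, 16, 8, 128
    total_bytes = elem_size * num_layers * 2 * num_blocks * block_size * kv_head_num * head_size  # 2 = K and V
    print(f"{total_bytes / GIGABYTE:.2f} GB")  # 4.00 GB with these made-up numbers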
@@ -453,7 +458,7 @@ class KVCacheManager:
         """
         assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
         blocks = []
-        physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
+        physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
         k_ptrs = [
             self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
         ]

@@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())

     @staticmethod
     def from_native_module(
@@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 mid_output=fd_inter_tensor.mid_output,
                 mid_output_lse=fd_inter_tensor.mid_output_lse,
                 sm_scale=sm_scale,
+                kv_group_num=self.num_key_value_groups,
                 q_len=q_len,
             )
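The flash-decoding call now receives the GQA group size. A short sketch of how that value is derived, using HF Llama config attribute names and illustrative Llama-3-8B head counts:

    # Hedged sketch: GQA group size passed as kv_group_num (illustrative head counts).
    num_attention_heads, num_key_value_heads = 32, 8
    assert num_attention_heads % num_key_value_heads == 0
    num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads share one KV head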

@@ -117,6 +117,14 @@ class Sequence:
             return False

+    def revoke_finished_status(self) -> None:
+        """
+        Revoke the finished status of the sequence.
+        This is only used by speculative decoding for now.
+        """
+        if RequestStatus.is_finished(self.status):
+            self.status = RequestStatus.RUNNING
+
     def __hash__(self):
         return hash(self.request_id)
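Speculative decoding may roll back draft tokens after verification; if a sequence was marked finished only because a rejected draft token looked like EOS, the finished status has to be revoked together with the tokens. A toy sketch (not the repo's classes) of that interaction:

    # Toy illustration of why token rollback must also un-finish the sequence.
    class ToySeq:
        def __init__(self):
            self.output_token_id, self.finished = [], False

        def rollback(self, n_tokens):
            self.output_token_id = self.output_token_id[:-n_tokens]
            self.finished = False  # analogue of revoke_finished_status()

    seq = ToySeq()
    seq.output_token_id += [1, 2, 0]  # draft tokens; 0 happens to be EOS
    seq.finished = True               # finished on the unverified EOS
    seq.rollback(2)                   # verification rejects the last two tokens
    assert seq.output_token_id == [1] and not seq.finished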

@@ -36,97 +36,91 @@ def rotary_embedding_kernel(
     cos_stride,
     q_total_tokens,
     Q_HEAD_NUM: tl.constexpr,
-    K_HEAD_NUM: tl.constexpr,
+    KV_GROUP_NUM: tl.constexpr,
     HEAD_DIM: tl.constexpr,
-    BLOCK_HEAD: tl.constexpr,
-    BLOCK_TOKENS: tl.constexpr,
+    BLOCK_TOKENS: tl.constexpr,  # token range length
 ):
-    block_head_index = tl.program_id(0)
-    block_token_index = tl.program_id(1)
-
-    tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
-    head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
-
-    dim_range0 = tl.arange(0, HEAD_DIM // 2)
-    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
-
-    off_q0 = (
-        tokens_range[:, None, None] * q_token_stride
-        + head_range[None, :, None] * q_head_stride
-        + dim_range0[None, None, :] * head_dim_stride
-    )
-    off_q1 = (
-        tokens_range[:, None, None] * q_token_stride
-        + head_range[None, :, None] * q_head_stride
-        + dim_range1[None, None, :] * head_dim_stride
-    )
-    off_k0 = (
-        tokens_range[:, None, None] * k_token_stride
-        + head_range[None, :, None] * k_head_stride
-        + dim_range0[None, None, :] * head_dim_stride
-    )
-    off_k1 = (
-        tokens_range[:, None, None] * k_token_stride
-        + head_range[None, :, None] * k_head_stride
-        + dim_range1[None, None, :] * head_dim_stride
-    )
-
-    loaded_q0 = tl.load(
-        q + off_q0,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-    loaded_q1 = tl.load(
-        q + off_q1,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-    loaded_k0 = tl.load(
-        k + off_k0,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-    loaded_k1 = tl.load(
-        k + off_k1,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-        other=0.0,
-    )
-
-    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
-    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
-    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
-
-    out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
-    out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
-    out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
-    out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
-
-    # concat
-    tl.store(
-        q + off_q0,
-        out_q0,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
-    tl.store(
-        q + off_q1,
-        out_q1,
-        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
-    tl.store(
-        k + off_k0,
-        out_k0,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
-    tl.store(
-        k + off_k1,
-        out_k1,
-        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
-    )
+    cur_head_idx = tl.program_id(0)
+    cur_token_block_idx = tl.program_id(1)
+
+    tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
+    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+
+    off_q0 = (
+        tokens_range[:, None, None] * q_token_stride
+        + cur_head_idx * q_head_stride
+        + dim_range0[None, None, :] * head_dim_stride
+    )
+    off_q1 = (
+        tokens_range[:, None, None] * q_token_stride
+        + cur_head_idx * q_head_stride
+        + dim_range1[None, None, :] * head_dim_stride
+    )
+    loaded_q0 = tl.load(
+        q + off_q0,
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+    loaded_q1 = tl.load(
+        q + off_q1,
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
+    out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
+    out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
+
+    tl.store(
+        q + off_q0,
+        out_q0,
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        q + off_q1,
+        out_q1,
+        mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+
+    handle_k = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_k:
+        k_head_idx = cur_head_idx // KV_GROUP_NUM
+        off_k0 = (
+            tokens_range[:, None, None] * k_token_stride
+            + k_head_idx * k_head_stride
+            + dim_range0[None, None, :] * head_dim_stride
+        )
+        off_k1 = (
+            tokens_range[:, None, None] * k_token_stride
+            + k_head_idx * k_head_stride
+            + dim_range1[None, None, :] * head_dim_stride
+        )
+        loaded_k0 = tl.load(
+            k + off_k0,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+            other=0.0,
+        )
+        loaded_k1 = tl.load(
+            k + off_k1,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+            other=0.0,
+        )
+        out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
+        out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
+        tl.store(
+            k + off_k0,
+            out_k0,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+        )
+        tl.store(
+            k + off_k1,
+            out_k1,
+            mask=(tokens_range[:, None, None] < q_total_tokens),
+        )


 @triton.jit
 def fused_rotary_embedding_kernel(
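The reworked kernel launches one program per query head and rotates K only in programs whose head index is a multiple of KV_GROUP_NUM, so each key head is written exactly once. A hedged PyTorch reference of the half-split rotation it applies to GQA-shaped q and k (shapes are illustrative; this is not the repo's reference implementation):

    # Hedged PyTorch reference of the half-split RoPE math applied above.
    import torch

    def rope_halves(x, cos, sin):
        # x: [T, H, D]; cos/sin: [T, D // 2]
        d2 = x.shape[-1] // 2
        x0, x1 = x[..., :d2], x[..., d2:]
        c, s = cos[:, None, :], sin[:, None, :]
        return torch.cat((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1)

    T, Hq, Hkv, D = 4, 8, 2, 16  # KV_GROUP_NUM = Hq // Hkv = 4
    q, k = torch.randn(T, Hq, D), torch.randn(T, Hkv, D)
    cos, sin = torch.randn(T, D // 2), torch.randn(T, D // 2)
    q_out, k_out = rope_halves(q, cos, sin), rope_halves(k, cos, sin)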
@@ -405,108 +399,74 @@ def decoding_fused_rotary_embedding_kernel(
     bts_stride,
     btb_stride,
     block_size,
-    Q_HEAD_NUM: tl.constexpr,
+    KV_GROUP_NUM: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    block_head_index = tl.program_id(0)
-    if block_head_index >= Q_HEAD_NUM:
-        return
-    block_token_index = tl.program_id(1)
-
-    dim_range0 = tl.arange(0, HEAD_DIM // 2)
-    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
-    total_dim_range = tl.arange(0, HEAD_DIM)
-
-    q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
-    off_q0 = q_off_base + dim_range0 * head_dim_stride
-    off_q1 = q_off_base + dim_range1 * head_dim_stride
-
-    off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
-    off_k0 = off_base + dim_range0 * head_dim_stride
-    off_k1 = off_base + dim_range1 * head_dim_stride
-    off_v = off_base + total_dim_range * head_dim_stride
-
-    loaded_q0 = tl.load(
-        q + off_q0,
-    )
-    loaded_q1 = tl.load(
-        q + off_q1,
-    )
-    loaded_k0 = tl.load(
-        k + off_k0,
-    )
-    loaded_k1 = tl.load(
-        k + off_k1,
-    )
-    loaded_v = tl.load(
-        v + off_v,
-    )
-
-    off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
-    loaded_cos = tl.load(cos + off_cos_sin)
-    loaded_sin = tl.load(sin + off_cos_sin)
-
-    out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
-    out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
-
-    out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
-    out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos
-
-    past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
-
-    last_block_idx = past_kv_seq_len // block_size
-    block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
-    offsets_in_last_block = past_kv_seq_len % block_size
-
-    k_range0 = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + dim_range0 * cache_d_stride
-    )
-    k_range1 = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + dim_range1 * cache_d_stride
-    )
-    v_range = (
-        block_ids * cache_b_stride
-        + block_head_index * cache_h_stride
-        + offsets_in_last_block * cache_bs_stride
-        + total_dim_range * cache_d_stride
-    )
-
-    tl.store(
-        v_cache + v_range,
-        loaded_v,
-    )
-    tl.store(
-        k_cache + k_range0,
-        out_k0,
-    )
-    tl.store(
-        k_cache + k_range1,
-        out_k1,
-    )
-
-    # concat
-    tl.store(
-        q + off_q0,
-        out_q0,
-    )
-    tl.store(
-        q + off_q1,
-        out_q1,
-    )
+    cur_head_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(1)
+
+    dim_range = tl.arange(0, HEAD_DIM)
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride
+    off_q0 = off_q + dim_range0 * head_dim_stride
+    off_q1 = off_q + dim_range1 * head_dim_stride
+
+    loaded_q0 = tl.load(q + off_q0)
+    loaded_q1 = tl.load(q + off_q1)
+
+    off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride
+    loaded_cos = tl.load(cos + off_cos_sin)
+    loaded_sin = tl.load(sin + off_cos_sin)
+
+    out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
+    out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
+    tl.store(q + off_q0, out_q0)
+    tl.store(q + off_q1, out_q1)
+
+    # total_tokens, head_num, head_dim
+    handle_k = cur_head_idx % KV_GROUP_NUM == 0
+    if handle_k:
+        cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
+        off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
+        off_k0 = off_kv + dim_range0 * head_dim_stride
+        off_k1 = off_kv + dim_range1 * head_dim_stride
+        loaded_k0 = tl.load(k + off_k0)
+        loaded_k1 = tl.load(k + off_k1)
+
+        out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
+        out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos
+
+        # NOTE The precondition here is that it's only for unpadded inputs during decoding stage,
+        # and so that we could directly use the token index as the sequence index
+        past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1
+
+        last_block_idx = past_kv_seq_len // block_size
+        block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
+        offsets_in_last_block = past_kv_seq_len % block_size
+        k_range0 = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range0 * cache_d_stride
+        )
+        k_range1 = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range1 * cache_d_stride
+        )
+        tl.store(k_cache + k_range0, out_k0)
+        tl.store(k_cache + k_range1, out_k1)
+
+        off_v = off_kv + dim_range * head_dim_stride
+        loaded_v = tl.load(v + off_v)
+        v_range = (
+            block_ids * cache_b_stride
+            + cur_k_head_idx * cache_h_stride
+            + offsets_in_last_block * cache_bs_stride
+            + dim_range * cache_d_stride
+        )
+        tl.store(v_cache + v_range, loaded_v)


 def rotary_embedding(
@@ -521,7 +481,7 @@ def rotary_embedding(
     """
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
-        k: key tensor, [total_tokens, head_num, head_dim]
+        k: key tensor, [total_tokens, kv_head_num, head_dim]
        cos: cosine for rotary embedding, [max_position_len, head_dim]
        sin: sine for rotary embedding, [max_position_len, head_dim]
        k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
@@ -530,32 +490,26 @@ def rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0)
-    BLOCK_HEAD = 4
     BLOCK_TOKENS = 4

-    if head_dim >= 1024:
-        num_warps = 32
-    elif head_dim >= 512:
+    if head_dim >= 512:
         num_warps = 16
     elif head_dim >= 256:
         num_warps = 8
     else:
         num_warps = 4

-    q_token_stride = q.stride(0)
-    q_head_stride = q.stride(1)
-    head_dim_stride = q.stride(2)
-
-    k_token_stride = k.stride(0)
-    k_head_stride = k.stride(1)
-
-    k_head_num = q.shape[1]
-
-    cos_token_stride = cos.stride(0)
-    cos_stride = cos.stride(1)
+    k_head_num = k.size(1)
+    q_token_stride, q_head_stride, head_dim_stride = q.stride()
+    k_token_stride, k_head_stride, _ = k.stride()
+    cos_token_stride, cos_stride = cos.stride()
+
+    assert q_head_num % k_head_num == 0
+    kv_group_num = q_head_num // k_head_num

     if k_cache == None:
         grid = lambda META: (
-            triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
+            q_head_num,
             triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
         )
         rotary_embedding_kernel[grid](
@@ -572,9 +526,8 @@ def rotary_embedding(
             cos_stride,
             q_total_tokens,
             Q_HEAD_NUM=q_head_num,
-            K_HEAD_NUM=k_head_num,
+            KV_GROUP_NUM=kv_group_num,
             HEAD_DIM=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
             BLOCK_TOKENS=BLOCK_TOKENS,
             num_warps=num_warps,
         )
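The launch grid now enumerates query heads directly instead of tiling them in blocks of BLOCK_HEAD. A minimal sketch of the resulting grid, with illustrative sizes:

    # Hedged sketch of the new launch grid: one program per (query head, token block).
    import triton
    q_head_num, q_total_tokens, BLOCK_TOKENS = 32, 100, 4
    grid = (q_head_num, triton.cdiv(q_total_tokens, BLOCK_TOKENS))
    print(grid)  # (32, 25); K is rotated only by programs where head_idx % kv_group_num == 0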
@@ -624,23 +577,21 @@ def decoding_fused_rotary_embedding(
     """
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
-        k: key tensor, [total_tokens, head_num, head_dim]
-        v: value tensor, [total tokens, head_num, head_dim]
+        k: key tensor, [total_tokens, kv_head_num, head_dim]
+        v: value tensor, [total tokens, kv_head_num, head_dim]
         cos: cosine for rotary embedding, [max_position_len, head_dim]
         sin: sine for rotary embedding, [max_position_len, head_dim]
-        k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
-        v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
+        k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim]
+        v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]
         kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
         block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert q.size(1) == k.size(1) == v.size(1)
+    assert k.size(1) == v.size(1)
     assert k_cache.size(-1) == v_cache.size(-1)
-    if head_dim >= 1024:
-        num_warps = 32
-    elif head_dim >= 512:
+    if head_dim >= 512:
         num_warps = 16
     elif head_dim >= 256:
         num_warps = 8
@@ -653,10 +604,12 @@ def decoding_fused_rotary_embedding(
     k_token_stride = k.stride(0)
     k_head_stride = k.stride(1)
+    k_head_num = k.size(1)
+    kv_group_num = q_head_num // k_head_num

     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)

-    grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
+    grid = (q_head_num, q_total_tokens)
     decoding_fused_rotary_embedding_kernel[grid](
         q,
         k,
@@ -681,7 +634,7 @@ def decoding_fused_rotary_embedding(
         block_tables.stride(0),
         block_tables.stride(1),
         k_cache.size(-2),
-        Q_HEAD_NUM=q_head_num,
+        KV_GROUP_NUM=kv_group_num,
         HEAD_DIM=head_dim,
         num_warps=num_warps,
     )
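During decoding, only every KV_GROUP_NUM-th head program writes the new K/V entry into the paged cache, locating the slot of the newest token via the block table. A small sketch of that index arithmetic (plain Python, illustrative values):

    # Hedged sketch of the paged-KV slot lookup used in the decoding kernel.
    def last_token_slot(kv_len, block_size, block_table):
        past = kv_len - 1                    # index of the token being written
        last_block = past // block_size
        return block_table[last_block], past % block_size  # (physical block id, offset in block)

    assert last_token_slot(kv_len=35, block_size=16, block_table=[7, 3, 9]) == (9, 2)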

@@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length):
     assert not engine.use_spec_dec
     assert engine.drafter is None and engine.drafter_model is None

+    max_new_tokens = max_length - dummy_inputs.size(1)
     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens

     # test GLIDE model
     glide_config = GlideLlamaConfig(
@@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length):
     engine.clear_spec_dec()

     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens


 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample):

 @parameterize("num_layers", [1])
-@parameterize("max_length", [100])
+@parameterize("max_length", [64])
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)

@@ -151,6 +151,16 @@ def test_flash_decoding_attention(
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


+try:
+    from vllm._C import ops as vllm_ops  # noqa
+
+    HAS_VLLM = True
+except ImportError:
+    HAS_VLLM = False
+    print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
+
+
+@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
 @pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
 @pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
 @pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@@ -166,11 +176,6 @@ def test_vllm_flash_decoding_attention(
     torch.cuda.synchronize()
     torch.cuda.reset_peak_memory_stats()

-    try:
-        from vllm._C import ops as vllm_ops
-    except ImportError:
-        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
-
     NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
     assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
     MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
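The test now probes the optional vllm dependency once at import time and skips instead of erroring when it is missing. A generic sketch of the same pattern (the test name here is hypothetical):

    # Generic optional-dependency guard mirroring the change above.
    import pytest

    try:
        from vllm._C import ops as vllm_ops  # noqa: F401
        HAS_VLLM = True
    except ImportError:
        HAS_VLLM = False

    @pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
    def test_needs_vllm():
        ...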
