[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] Fix GQA call into the flash-decoding Triton kernel

* Fix KV cache allocation shape (use kv_head_num)

* Fix rotary embedding Triton kernels for GQA

* Fix sequence max-length assignment

* Rework sequence max-length logic

* Fix scheduling and speculative decoding

* Skip tests when the optional dependency is missing

* Fix pytest: skip instead of raising ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Commit 5d4c1fe8f5 (parent ccf72797e3), authored by Yuanheng Zhao, committed via GitHub.
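For context on the GQA-related hunks below: Llama 3 uses grouped-query attention (GQA), where several query heads share one key/value head. A minimal sketch of the grouping the kernels and the KV cache manager now rely on, with illustrative Llama-3-8B-style values (assumptions, not taken from this diff):

```python
# Illustrative GQA head grouping (values are assumptions, Llama-3-8B-style).
num_attention_heads = 32   # query heads
num_key_value_heads = 8    # shared key/value heads

assert num_attention_heads % num_key_value_heads == 0
kv_group_num = num_attention_heads // num_key_value_heads  # 4 query heads per kv head

# Query head q_idx reads from / writes rotary results for kv head q_idx // kv_group_num.
kv_head_for_q = [q_idx // kv_group_num for q_idx in range(num_attention_heads)]
```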

@@ -386,6 +386,7 @@ class BatchBucket:
seq_id, seq = next(seqs_iter)
assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
seq.output_token_id = seq.output_token_id[:-n_tokens]
seq.revoke_finished_status()
self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:

@@ -518,7 +518,13 @@ class InferenceEngine:
"""
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
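The hunk above flattens an optional transformers GenerationConfig into keyword arguments for add_request. A hedged usage sketch (engine construction is elided; only the dict flattening mirrors the change):

```python
# Sketch of the new generation-config forwarding; the dict flattening mirrors the
# hunk above, everything around the engine itself is elided.
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=64)
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}

# add_request() later reads these kwargs to size each sequence's output budget.
assert gen_config_dict.get("max_new_tokens") == 64
```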
@@ -573,6 +579,7 @@ class InferenceEngine:
request_ids: List[int] = None,
prompts: List[str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
) -> None:
"""
Add requests.
@@ -629,6 +636,13 @@ class InferenceEngine:
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
sequence = Sequence(
request_id,
prompt,
@@ -637,7 +651,7 @@ class InferenceEngine:
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
max_output_len=max_new_tokens,
)
self.request_handler.add_sequence(sequence)
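The added branch resolves each request's output budget: max_length bounds prompt plus output, max_new_tokens bounds output only, and the engine or generation-config default applies when neither is given. A standalone sketch of that precedence (the default value here is an assumption):

```python
# Standalone sketch of the max_new_tokens resolution added to add_request().
def resolve_max_new_tokens(prompt_len, max_length=None, max_new_tokens=None,
                           default_max_new_tokens=64):
    """max_length counts prompt + output tokens; max_new_tokens counts output only."""
    if max_length is None and max_new_tokens is None:
        return default_max_new_tokens          # fall back to config defaults
    if max_length is not None:
        return max_length - prompt_len         # max_length takes precedence
    return max_new_tokens

assert resolve_max_new_tokens(prompt_len=10, max_length=30) == 20
assert resolve_max_new_tokens(prompt_len=10, max_new_tokens=16) == 16
```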

@@ -314,10 +314,11 @@ class RequestHandler:
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
for seq in batch.seqs_li:
if (
seq.output_token_id[-1] == generation_config.eos_token_id
or seq.output_len >= generation_config.max_length
):
max_length = generation_config.max_length
max_new_tokens = generation_config.max_new_tokens
if max_length is not None:
max_new_tokens = max_length - seq.input_len
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
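The finish check above now derives the output budget from max_length when it is set, instead of comparing output_len against max_length directly. A small sketch of the corrected condition:

```python
# Sketch of the revised finish condition: stop on EOS, or once the generated
# length reaches the derived max_new_tokens.
def sequence_finished(last_token_id, output_len, input_len, eos_token_id,
                      max_new_tokens=None, max_length=None):
    if max_length is not None:
        max_new_tokens = max_length - input_len
    return last_token_id == eos_token_id or output_len >= max_new_tokens

assert sequence_finished(2, 5, 10, eos_token_id=2, max_new_tokens=50)   # EOS reached
assert sequence_finished(7, 20, 12, eos_token_id=2, max_length=32)      # 32 - 12 = 20 generated
```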

@@ -38,7 +38,7 @@ class KVCacheManager:
The block table after block allocation might be:
| 0 | 1 | 2 | -1 | -1 | -1 |
Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
corresponding to these blocks will be used to read/write KV Caches in kernels.
For a batch of sequences, the block tables after allocation might be:
@@ -64,9 +64,12 @@ class KVCacheManager:
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size
assert (
self.kv_head_num % self.tp_size == 0
), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
self.kv_head_num //= self.tp_size
self.beam_width = config.beam_width
self.max_batch_size = config.max_batch_size
self.max_input_length = config.max_input_len
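With GQA, the physical cache is sized by kv_head_num (after tensor-parallel sharding), not by the query head count. A worked example of the per-block footprint from the docstring above, using assumed Llama-3-8B-style numbers (fp16, block_size 16, tp_size 2):

```python
# Worked example of the per-block, per-layer cache footprint; all values are
# illustrative assumptions, not read from a real model config.
elem_size = 2                 # fp16
block_size = 16               # tokens per logical block
head_size = 128               # hidden_size 4096 // 32 query heads
kv_head_num, tp_size = 8, 2

assert kv_head_num % tp_size == 0
kv_head_num //= tp_size       # 4 kv heads per tensor-parallel rank

bytes_per_block = block_size * kv_head_num * head_size * elem_size
print(bytes_per_block)        # 16384 bytes of K per block per layer (same again for V)
```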
@@ -80,9 +83,8 @@ class KVCacheManager:
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
# Physical cache allocation
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
# if verbose:
# self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
@@ -90,9 +92,12 @@ class KVCacheManager:
* 2
* self.num_blocks
* self.block_size
* self.head_num
* self.kv_head_num
* self.head_size
)
self.logger.info(
f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
)
# Logical cache blocks allocation
self._available_blocks = self.num_blocks
self._cache_blocks = tuple(self._init_logical_caches())
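The total allocation logged above multiplies that per-token footprint across both K and V, all layers, and all blocks. A quick sanity computation under assumed sizes (num_blocks is an illustrative budget, not a real config value):

```python
# Sanity computation of total_physical_cache_size_in_bytes under assumed sizes.
GIGABYTE = 1024**3
elem_size, num_layers = 2, 32          # fp16, Llama-3-8B-style depth
num_blocks, block_size = 2048, 16      # illustrative block budget
kv_head_num, head_size = 8, 128

total_bytes = elem_size * num_layers * 2 * num_blocks * block_size * kv_head_num * head_size
print(f"{total_bytes / GIGABYTE:.2f} GB")   # 4.00 GB of KV cache (K and V, all layers)
```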
@@ -453,7 +458,7 @@ class KVCacheManager:
"""
assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
blocks = []
physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
k_ptrs = [
self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
]

@@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
else:
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w
self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@staticmethod
def from_native_module(
@@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)

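Two things change in NopadLlamaAttention above: the unfused projection weights are now stored pre-transposed as parameters, and kv_group_num (num_key_value_groups) is forwarded to the flash-decoding kernel. A hedged sketch of why the transpose helps, assuming the no-pad forward multiplies token-major activations directly against the stored weight (toy shapes, not the module's real code):

```python
# Hedged sketch: storing the projection weight as [in_features, out_features]
# lets a token-major matmul run without a transpose at call time.
import torch
from torch import nn

hidden_size, num_tokens = 16, 4
hf_weight = torch.randn(hidden_size, hidden_size)                      # HF Linear keeps [out, in]
q_proj_weight = nn.Parameter(hf_weight.transpose(0, 1).contiguous())   # stored as [in, out]

hidden_states = torch.randn(num_tokens, hidden_size)
query_states = torch.mm(hidden_states, q_proj_weight)                  # no runtime transpose
assert torch.allclose(query_states, torch.nn.functional.linear(hidden_states, hf_weight))
```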
@@ -117,6 +117,14 @@ class Sequence:
return False
def revoke_finished_status(self) -> None:
"""
Revoke the finished status of the sequence.
This is only used by speculative decoding for now.
"""
if RequestStatus.is_finished(self.status):
self.status = RequestStatus.RUNNING
def __hash__(self):
return hash(self.request_id)
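revoke_finished_status pairs with the revoke_batch_tokens change in the first hunk: when speculative tokens are rejected and rolled back, a sequence that was prematurely marked finished must return to RUNNING. A toy sketch with simplified stand-in classes (not the real Sequence/RequestStatus):

```python
# Toy sketch of the spec-dec rollback; RequestStatus/ToySequence are stand-ins.
from enum import Enum

class RequestStatus(Enum):
    RUNNING = 1
    COMPLETED = 2

class ToySequence:
    def __init__(self, output_token_id):
        self.output_token_id = output_token_id
        self.status = RequestStatus.RUNNING

    def revoke_finished_status(self):
        if self.status == RequestStatus.COMPLETED:
            self.status = RequestStatus.RUNNING

seq = ToySequence([11, 12, 13, 2])        # last speculative token happened to be EOS
seq.status = RequestStatus.COMPLETED
n_rejected = 2
seq.output_token_id = seq.output_token_id[:-n_rejected]  # roll back rejected tokens
seq.revoke_finished_status()                              # sequence keeps decoding
assert seq.status == RequestStatus.RUNNING
```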

@@ -36,97 +36,91 @@ def rotary_embedding_kernel(
cos_stride,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
K_HEAD_NUM: tl.constexpr,
KV_GROUP_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
BLOCK_TOKENS: tl.constexpr, # token range length
):
block_head_index = tl.program_id(0)
block_token_index = tl.program_id(1)
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
cur_head_idx = tl.program_id(0)
cur_token_block_idx = tl.program_id(1)
tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
off_q0 = (
tokens_range[:, None, None] * q_token_stride
+ head_range[None, :, None] * q_head_stride
+ cur_head_idx * q_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_q1 = (
tokens_range[:, None, None] * q_token_stride
+ head_range[None, :, None] * q_head_stride
+ cur_head_idx * q_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
off_k0 = (
tokens_range[:, None, None] * k_token_stride
+ head_range[None, :, None] * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
tokens_range[:, None, None] * k_token_stride
+ head_range[None, :, None] * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
loaded_q0 = tl.load(
q + off_q0,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_q1 = tl.load(
q + off_q1,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_k0 = tl.load(
k + off_k0,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_k1 = tl.load(
k + off_k1,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
# concat
tl.store(
q + off_q0,
out_q0,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
q + off_q1,
out_q1,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
k + off_k0,
out_k0,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
k + off_k1,
out_k1,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
handle_k = cur_head_idx % KV_GROUP_NUM == 0
if handle_k:
k_head_idx = cur_head_idx // KV_GROUP_NUM
off_k0 = (
tokens_range[:, None, None] * k_token_stride
+ k_head_idx * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
tokens_range[:, None, None] * k_token_stride
+ k_head_idx * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
loaded_k0 = tl.load(
k + off_k0,
mask=(tokens_range[:, None, None] < q_total_tokens),
other=0.0,
)
loaded_k1 = tl.load(
k + off_k1,
mask=(tokens_range[:, None, None] < q_total_tokens),
other=0.0,
)
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
tl.store(
k + off_k0,
out_k0,
mask=(tokens_range[:, None, None] < q_total_tokens),
)
tl.store(
k + off_k1,
out_k1,
mask=(tokens_range[:, None, None] < q_total_tokens),
)
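A plain PyTorch reference (not the Triton kernel) for what the GQA-aware prefill kernel above computes: every query head is rotated, and each key head is rotated exactly once, handled by the program whose query-head index is a multiple of KV_GROUP_NUM. Here cos/sin are assumed to be pre-gathered per token with head_dim // 2 columns:

```python
# PyTorch reference for the rotary math above (half-split rotation, GQA shapes).
import torch

def rotary_ref(q, k, cos, sin):
    """q: [T, q_heads, D], k: [T, kv_heads, D], cos/sin: [T, D // 2]; in-place."""
    def rotate(x):
        d = x.size(-1) // 2
        x0, x1 = x[..., :d], x[..., d:]
        c, s = cos[:, None, :], sin[:, None, :]
        return torch.cat([x0 * c - x1 * s, x0 * s + x1 * c], dim=-1)
    q.copy_(rotate(q))
    k.copy_(rotate(k))   # kv heads rotated once each, regardless of the q/kv head ratio

T, Q_HEADS, KV_HEADS, D = 8, 32, 8, 128
q, k = torch.randn(T, Q_HEADS, D), torch.randn(T, KV_HEADS, D)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.outer(torch.arange(T, dtype=torch.float32), inv_freq)   # [T, D // 2]
rotary_ref(q, k, angles.cos(), angles.sin())
```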
@triton.jit
def fused_rotary_embedding_kernel(
@@ -405,108 +399,74 @@ def decoding_fused_rotary_embedding_kernel(
bts_stride,
btb_stride,
block_size,
Q_HEAD_NUM: tl.constexpr,
KV_GROUP_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
block_head_index = tl.program_id(0)
if block_head_index >= Q_HEAD_NUM:
return
block_token_index = tl.program_id(1)
cur_head_idx = tl.program_id(0)
cur_token_idx = tl.program_id(1)
dim_range = tl.arange(0, HEAD_DIM)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
total_dim_range = tl.arange(0, HEAD_DIM)
q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
off_q0 = q_off_base + dim_range0 * head_dim_stride
off_q1 = q_off_base + dim_range1 * head_dim_stride
off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
off_k0 = off_base + dim_range0 * head_dim_stride
off_k1 = off_base + dim_range1 * head_dim_stride
off_v = off_base + total_dim_range * head_dim_stride
loaded_q0 = tl.load(
q + off_q0,
)
loaded_q1 = tl.load(
q + off_q1,
)
loaded_k0 = tl.load(
k + off_k0,
)
loaded_k1 = tl.load(
k + off_k1,
)
loaded_v = tl.load(
v + off_v,
)
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride
off_q0 = off_q + dim_range0 * head_dim_stride
off_q1 = off_q + dim_range1 * head_dim_stride
loaded_q0 = tl.load(q + off_q0)
loaded_q1 = tl.load(q + off_q1)
off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride
loaded_cos = tl.load(cos + off_cos_sin)
loaded_sin = tl.load(sin + off_cos_sin)
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
last_block_idx = past_kv_seq_len // block_size
block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
offsets_in_last_block = past_kv_seq_len % block_size
k_range0 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range0 * cache_d_stride
)
k_range1 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range1 * cache_d_stride
)
v_range = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ total_dim_range * cache_d_stride
)
tl.store(
v_cache + v_range,
loaded_v,
)
tl.store(
k_cache + k_range0,
out_k0,
)
tl.store(
k_cache + k_range1,
out_k1,
)
# concat
tl.store(
q + off_q0,
out_q0,
)
tl.store(
q + off_q1,
out_q1,
)
tl.store(q + off_q0, out_q0)
tl.store(q + off_q1, out_q1)
handle_k = cur_head_idx % KV_GROUP_NUM == 0
if handle_k:
cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
off_k0 = off_kv + dim_range0 * head_dim_stride
off_k1 = off_kv + dim_range1 * head_dim_stride
loaded_k0 = tl.load(k + off_k0)
loaded_k1 = tl.load(k + off_k1)
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos
# NOTE The precondition here is that it's only for unpadded inputs during decoding stage,
# and so that we could directly use the token index as the sequence index
past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1
last_block_idx = past_kv_seq_len // block_size
block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
offsets_in_last_block = past_kv_seq_len % block_size
k_range0 = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range0 * cache_d_stride
)
k_range1 = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range1 * cache_d_stride
)
tl.store(k_cache + k_range0, out_k0)
tl.store(k_cache + k_range1, out_k1)
off_v = off_kv + dim_range * head_dim_stride
loaded_v = tl.load(v + off_v)
v_range = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range * cache_d_stride
)
tl.store(v_cache + v_range, loaded_v)
def rotary_embedding(
@@ -521,7 +481,7 @@ def rotary_embedding(
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, kv_head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
@@ -530,32 +490,26 @@ def rotary_embedding(
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 4
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 512:
if head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
else:
num_warps = 4
q_token_stride = q.stride(0)
q_head_stride = q.stride(1)
head_dim_stride = q.stride(2)
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
k_head_num = k.size(1)
q_token_stride, q_head_stride, head_dim_stride = q.stride()
k_token_stride, k_head_stride, _ = k.stride()
cos_token_stride, cos_stride = cos.stride()
k_head_num = q.shape[1]
assert q_head_num % k_head_num == 0
kv_group_num = q_head_num // k_head_num
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
if k_cache == None:
grid = lambda META: (
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
q_head_num,
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
)
rotary_embedding_kernel[grid](
@@ -572,9 +526,8 @@ def rotary_embedding(
cos_stride,
q_total_tokens,
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
KV_GROUP_NUM=kv_group_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_TOKENS=BLOCK_TOKENS,
num_warps=num_warps,
)
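The launch grid for the prefill kernel changes from a BLOCK_HEAD tile over heads to one program per query head (times token blocks), which is what lets each program decide whether it also owns a kv head. A tiny sketch of the resulting grid shape, with assumed sizes:

```python
# Sketch of the new launch grid: (q_head_num, ceil(total_tokens / BLOCK_TOKENS)).
import triton

q_head_num, q_total_tokens, BLOCK_TOKENS = 32, 100, 4
grid = (q_head_num, triton.cdiv(q_total_tokens, BLOCK_TOKENS))
print(grid)   # (32, 25): one program per query head per 4-token block
```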
@@ -624,23 +577,21 @@ def decoding_fused_rotary_embedding(
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
v: value tensor, [total tokens, head_num, head_dim]
k: key tensor, [total_tokens, kv_head_num, head_dim]
v: value tensor, [total tokens, kv_head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim]
v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]
kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0) == v.size(0)
assert q.size(1) == k.size(1) == v.size(1)
assert k.size(1) == v.size(1)
assert k_cache.size(-1) == v_cache.size(-1)
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 512:
if head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
@@ -653,10 +604,12 @@ def decoding_fused_rotary_embedding(
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
k_head_num = k.size(1)
kv_group_num = q_head_num // k_head_num
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
grid = (q_head_num, q_total_tokens)
decoding_fused_rotary_embedding_kernel[grid](
q,
k,
@@ -681,7 +634,7 @@ def decoding_fused_rotary_embedding(
block_tables.stride(0),
block_tables.stride(1),
k_cache.size(-2),
Q_HEAD_NUM=q_head_num,
KV_GROUP_NUM=kv_group_num,
HEAD_DIM=head_dim,
num_warps=num_warps,
)
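For the decoding path, the fused kernel writes the rotated key and the value into the blocked caches at the slot of the token currently being decoded. A small sketch of that index arithmetic with toy values (the block table entries are made up):

```python
# Toy sketch of the blocked-cache slot computation used by the decoding kernel.
block_size = 16
kv_seq_len = 35                                   # past tokens + the one being decoded
past_kv_seq_len = kv_seq_len - 1                  # slot index of the current token
last_block_idx = past_kv_seq_len // block_size    # 2 -> third logical block
offset_in_block = past_kv_seq_len % block_size    # 2 -> slot inside that block

block_table = [5, 9, 13, -1]                      # logical -> physical block ids (made up)
physical_block_id = block_table[last_block_idx]   # 13
# k/v land at cache[physical_block_id, kv_head_idx, offset_in_block, :]
print(physical_block_id, offset_in_block)         # 13 2
```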

@@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length):
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
max_new_tokens = max_length - dummy_inputs.size(1)
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
# test GLIDE model
glide_config = GlideLlamaConfig(
@@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length):
engine.clear_spec_dec()
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample):
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
@parameterize("max_length", [64])
def test_spec_dec(num_layers, max_length):
spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)

@@ -151,6 +151,16 @@ def test_flash_decoding_attention(
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
try:
from vllm._C import ops as vllm_ops # noqa
HAS_VLLM = True
except ImportError:
HAS_VLLM = False
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@@ -166,11 +176,6 @@ def test_vllm_flash_decoding_attention(
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
try:
from vllm._C import ops as vllm_ops
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
