From 5d4c1fe8f5f7019284f6cbc0ed29506748f63bf1 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:09:55 +0800 Subject: [PATCH] [Fix/Inference] Fix GQA Triton and Support Llama3 (#5624) * [fix] GQA calling of flash decoding triton * fix kv cache alloc shape * fix rotary triton - GQA * fix sequence max length assigning * Sequence max length logic * fix scheduling and spec-dec * skip without import error * fix pytest - skip without ImportError --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/inference/batch_bucket.py | 1 + colossalai/inference/core/engine.py | 18 +- colossalai/inference/core/request_handler.py | 9 +- .../inference/kv_cache/kvcache_manager.py | 21 +- .../modeling/models/nopadding_llama.py | 7 +- colossalai/inference/struct.py | 8 + .../kernel/triton/no_pad_rotary_embedding.py | 301 ++++++++---------- tests/test_infer/test_inference_engine.py | 7 +- .../cuda/test_flash_decoding_attention.py | 15 +- 9 files changed, 188 insertions(+), 199 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index a2a2e74e8..726dfd614 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -386,6 +386,7 @@ class BatchBucket: seq_id, seq = next(seqs_iter) assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence" seq.output_token_id = seq.output_token_id[:-n_tokens] + seq.revoke_finished_status() self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]: diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index c30db3e0c..557a32fb6 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -518,7 +518,13 @@ class InferenceEngine: """ with torch.inference_mode(): if prompts is not None or prompts_token_ids is not None: - self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids) + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + self.add_request( + request_ids=request_ids, + prompts=prompts, + prompts_token_ids=prompts_token_ids, + **gen_config_dict, + ) output_seqs_list = [] total_tokens_list = [] @@ -573,6 +579,7 @@ class InferenceEngine: request_ids: List[int] = None, prompts: List[str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + **kwargs, ) -> None: """ Add requests. 
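For context on the hunk above: generate() now serializes the Hugging Face GenerationConfig and forwards its fields into add_request() as keyword arguments, so a per-request max_new_tokens (or max_length) can override the engine-wide inference_config.max_output_len. A minimal caller-side sketch of that flow follows; the import paths, constructor arguments, and checkpoint name are assumptions for illustration, not part of this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from colossalai.inference.config import InferenceConfig        # assumed import path
from colossalai.inference.core.engine import InferenceEngine

model_path = "meta-llama/Meta-Llama-3-8B"                      # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inference_config = InferenceConfig(max_output_len=128)         # engine-wide default
engine = InferenceEngine(model, tokenizer, inference_config)   # assumed constructor order

# max_new_tokens travels through **gen_config_dict into add_request() and becomes the
# per-request Sequence.max_output_len, overriding the engine-wide default above.
out = engine.generate(
    prompts=["Introduce some landmarks in Beijing"],
    generation_config=GenerationConfig(max_new_tokens=32),
)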
@@ -629,6 +636,13 @@ class InferenceEngine: else: prompt = prompts[i] + max_length = kwargs.get("max_length", None) + max_new_tokens = kwargs.get("max_new_tokens", None) + if max_length is None and max_new_tokens is None: + max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len + elif max_length is not None: + max_new_tokens = max_length - len(prompts_token_ids[i]) + sequence = Sequence( request_id, prompt, @@ -637,7 +651,7 @@ class InferenceEngine: None, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, - self.inference_config.max_output_len, + max_output_len=max_new_tokens, ) self.request_handler.add_sequence(sequence) diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 61ae3a4df..d80572599 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -314,10 +314,11 @@ class RequestHandler: def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig): for seq in batch.seqs_li: - if ( - seq.output_token_id[-1] == generation_config.eos_token_id - or seq.output_len >= generation_config.max_length - ): + max_length = generation_config.max_length + max_new_tokens = generation_config.max_new_tokens + if max_length is not None: + max_new_tokens = max_length - seq.input_len + if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens: seq.mark_finished() def check_unfinished_seqs(self) -> bool: diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 2b6445d1c..27ceca426 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -38,7 +38,7 @@ class KVCacheManager: The block table after block allocation might be: | 0 | 1 | 2 | -1 | -1 | -1 | Then the logical blocks with id 0, 1, and 2, are allocated for this sequence, - and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer, + and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer, corresponding to these blocks will be used to read/write KV Caches in kernels. 
For a batch of sequences, the block tables after allocation might be: @@ -64,9 +64,12 @@ class KVCacheManager: self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() self.num_layers = get_model_config_attr(model_config, "num_hidden_layers") self.head_num = get_model_config_attr(model_config, "num_attention_heads") + self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads") self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num - assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" - self.head_num //= self.tp_size + assert ( + self.kv_head_num % self.tp_size == 0 + ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}" + self.kv_head_num //= self.tp_size self.beam_width = config.beam_width self.max_batch_size = config.max_batch_size self.max_input_length = config.max_input_len @@ -80,9 +83,8 @@ class KVCacheManager: self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation - alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size) - # if verbose: - # self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") self._kv_caches = self._init_device_caches(alloc_shape) self.total_physical_cache_size_in_bytes = ( self.elem_size_in_bytes @@ -90,9 +92,12 @@ class KVCacheManager: * 2 * self.num_blocks * self.block_size - * self.head_num + * self.kv_head_num * self.head_size ) + self.logger.info( + f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}." 
+ ) # Logical cache blocks allocation self._available_blocks = self.num_blocks self._cache_blocks = tuple(self._init_logical_caches()) @@ -453,7 +458,7 @@ class KVCacheManager: """ assert self._kv_caches is not None and len(self._kv_caches[0]) > 0 blocks = [] - physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size + physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size k_ptrs = [ self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers) ] diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index be05e0838..ff5a159cd 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): attn_qproj_w.dist_layout ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) else: - self.q_proj_weight = attn_qproj_w - self.k_proj_weight = attn_kproj_w - self.v_proj_weight = attn_vproj_w + self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous()) + self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous()) + self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous()) @staticmethod def from_native_module( @@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention): mid_output=fd_inter_tensor.mid_output, mid_output_lse=fd_inter_tensor.mid_output_lse, sm_scale=sm_scale, + kv_group_num=self.num_key_value_groups, q_len=q_len, ) diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1fe732df0..fade655e1 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -117,6 +117,14 @@ class Sequence: return False + def revoke_finished_status(self) -> None: + """ + Revoke the finished status of the sequence. + This is only used by speculative decoding for now. 
+ """ + if RequestStatus.is_finished(self.status): + self.status = RequestStatus.RUNNING + def __hash__(self): return hash(self.request_id) diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 4b294a399..ad3946353 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -36,97 +36,91 @@ def rotary_embedding_kernel( cos_stride, q_total_tokens, Q_HEAD_NUM: tl.constexpr, - K_HEAD_NUM: tl.constexpr, + KV_GROUP_NUM: tl.constexpr, HEAD_DIM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_TOKENS: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, # token range length ): - block_head_index = tl.program_id(0) - block_token_index = tl.program_id(1) - - tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) - head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + cur_head_idx = tl.program_id(0) + cur_token_block_idx = tl.program_id(1) + tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride + loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + off_q0 = ( tokens_range[:, None, None] * q_token_stride - + head_range[None, :, None] * q_head_stride + + cur_head_idx * q_head_stride + dim_range0[None, None, :] * head_dim_stride ) off_q1 = ( tokens_range[:, None, None] * q_token_stride - + head_range[None, :, None] * q_head_stride + + cur_head_idx * q_head_stride + dim_range1[None, None, :] * head_dim_stride ) - off_k0 = ( - tokens_range[:, None, None] * k_token_stride - + head_range[None, :, None] * k_head_stride - + dim_range0[None, None, :] * head_dim_stride - ) - off_k1 = ( - tokens_range[:, None, None] * k_token_stride - + head_range[None, :, None] * k_head_stride - + dim_range1[None, None, :] * head_dim_stride - ) - loaded_q0 = tl.load( q + off_q0, - mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), other=0.0, ) loaded_q1 = tl.load( q + off_q1, - mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), other=0.0, ) - - loaded_k0 = tl.load( - k + off_k0, - mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), - other=0.0, - ) - - loaded_k1 = tl.load( - k + off_k1, - mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), - other=0.0, - ) - - off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride - - loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) - loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) - out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :] out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :] - out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] - out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, 
None, :] - - # concat tl.store( q + off_q0, out_q0, - mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), ) tl.store( q + off_q1, out_q1, - mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), - ) - tl.store( - k + off_k0, - out_k0, - mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), - ) - tl.store( - k + off_k1, - out_k1, - mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), ) + handle_k = cur_head_idx % KV_GROUP_NUM == 0 + if handle_k: + k_head_idx = cur_head_idx // KV_GROUP_NUM + off_k0 = ( + tokens_range[:, None, None] * k_token_stride + + k_head_idx * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + tokens_range[:, None, None] * k_token_stride + + k_head_idx * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + loaded_k0 = tl.load( + k + off_k0, + mask=(tokens_range[:, None, None] < q_total_tokens), + other=0.0, + ) + loaded_k1 = tl.load( + k + off_k1, + mask=(tokens_range[:, None, None] < q_total_tokens), + other=0.0, + ) + out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] + out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] + tl.store( + k + off_k0, + out_k0, + mask=(tokens_range[:, None, None] < q_total_tokens), + ) + tl.store( + k + off_k1, + out_k1, + mask=(tokens_range[:, None, None] < q_total_tokens), + ) + @triton.jit def fused_rotary_embedding_kernel( @@ -405,108 +399,74 @@ def decoding_fused_rotary_embedding_kernel( bts_stride, btb_stride, block_size, - Q_HEAD_NUM: tl.constexpr, + KV_GROUP_NUM: tl.constexpr, HEAD_DIM: tl.constexpr, ): - block_head_index = tl.program_id(0) - if block_head_index >= Q_HEAD_NUM: - return - - block_token_index = tl.program_id(1) + cur_head_idx = tl.program_id(0) + cur_token_idx = tl.program_id(1) + dim_range = tl.arange(0, HEAD_DIM) dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - total_dim_range = tl.arange(0, HEAD_DIM) - - q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride - off_q0 = q_off_base + dim_range0 * head_dim_stride - off_q1 = q_off_base + dim_range1 * head_dim_stride - - off_base = block_token_index * k_token_stride + block_head_index * k_head_stride - off_k0 = off_base + dim_range0 * head_dim_stride - off_k1 = off_base + dim_range1 * head_dim_stride - - off_v = off_base + total_dim_range * head_dim_stride - - loaded_q0 = tl.load( - q + off_q0, - ) - loaded_q1 = tl.load( - q + off_q1, - ) - loaded_k0 = tl.load( - k + off_k0, - ) - - loaded_k1 = tl.load( - k + off_k1, - ) - - loaded_v = tl.load( - v + off_v, - ) - - off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride + off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride + off_q0 = off_q + dim_range0 * head_dim_stride + off_q1 = off_q + dim_range1 * head_dim_stride + loaded_q0 = tl.load(q + off_q0) + loaded_q1 = tl.load(q + off_q1) + off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride loaded_cos = tl.load(cos + off_cos_sin) loaded_sin = tl.load(sin + off_cos_sin) out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos - - out_k0 = 
loaded_k0 * loaded_cos - loaded_k1 * loaded_sin - out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim - - past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1 - - last_block_idx = past_kv_seq_len // block_size - block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride) - offsets_in_last_block = past_kv_seq_len % block_size - - k_range0 = ( - block_ids * cache_b_stride - + block_head_index * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + dim_range0 * cache_d_stride - ) - k_range1 = ( - block_ids * cache_b_stride - + block_head_index * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + dim_range1 * cache_d_stride - ) - v_range = ( - block_ids * cache_b_stride - + block_head_index * cache_h_stride - + offsets_in_last_block * cache_bs_stride - + total_dim_range * cache_d_stride - ) - - tl.store( - v_cache + v_range, - loaded_v, - ) - - tl.store( - k_cache + k_range0, - out_k0, - ) - - tl.store( - k_cache + k_range1, - out_k1, - ) - - # concat - tl.store( - q + off_q0, - out_q0, - ) - tl.store( - q + off_q1, - out_q1, - ) + tl.store(q + off_q0, out_q0) + tl.store(q + off_q1, out_q1) + + handle_k = cur_head_idx % KV_GROUP_NUM == 0 + if handle_k: + cur_k_head_idx = cur_head_idx // KV_GROUP_NUM + off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride + off_k0 = off_kv + dim_range0 * head_dim_stride + off_k1 = off_kv + dim_range1 * head_dim_stride + loaded_k0 = tl.load(k + off_k0) + loaded_k1 = tl.load(k + off_k1) + + out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin + out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos + + # NOTE The precondition here is that it's only for unpadded inputs during decoding stage, + # and so that we could directly use the token index as the sequence index + past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride) + offsets_in_last_block = past_kv_seq_len % block_size + k_range0 = ( + block_ids * cache_b_stride + + cur_k_head_idx * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + dim_range0 * cache_d_stride + ) + k_range1 = ( + block_ids * cache_b_stride + + cur_k_head_idx * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + dim_range1 * cache_d_stride + ) + tl.store(k_cache + k_range0, out_k0) + tl.store(k_cache + k_range1, out_k1) + + off_v = off_kv + dim_range * head_dim_stride + loaded_v = tl.load(v + off_v) + v_range = ( + block_ids * cache_b_stride + + cur_k_head_idx * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + dim_range * cache_d_stride + ) + tl.store(v_cache + v_range, loaded_v) def rotary_embedding( @@ -521,7 +481,7 @@ def rotary_embedding( """ Args: q: query tensor, [total_tokens, head_num, head_dim] - k: key tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, kv_head_num, head_dim] cos: cosine for rotary embedding, [max_position_len, head_dim] sin: sine for rotary embedding, [max_position_len, head_dim] k_cache (torch.Tensor): Blocked key cache. 
[num_blocks, num_kv_heads, block_size, head_dim] @@ -530,32 +490,26 @@ def rotary_embedding( """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) - BLOCK_HEAD = 4 BLOCK_TOKENS = 4 - if head_dim >= 1024: - num_warps = 32 - elif head_dim >= 512: + if head_dim >= 512: num_warps = 16 elif head_dim >= 256: num_warps = 8 else: num_warps = 4 - q_token_stride = q.stride(0) - q_head_stride = q.stride(1) - head_dim_stride = q.stride(2) - - k_token_stride = k.stride(0) - k_head_stride = k.stride(1) + k_head_num = k.size(1) + q_token_stride, q_head_stride, head_dim_stride = q.stride() + k_token_stride, k_head_stride, _ = k.stride() + cos_token_stride, cos_stride = cos.stride() - k_head_num = q.shape[1] + assert q_head_num % k_head_num == 0 + kv_group_num = q_head_num // k_head_num - cos_token_stride = cos.stride(0) - cos_stride = cos.stride(1) if k_cache == None: grid = lambda META: ( - triton.cdiv(q_head_num, META["BLOCK_HEAD"]), + q_head_num, triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]), ) rotary_embedding_kernel[grid]( @@ -572,9 +526,8 @@ def rotary_embedding( cos_stride, q_total_tokens, Q_HEAD_NUM=q_head_num, - K_HEAD_NUM=k_head_num, + KV_GROUP_NUM=kv_group_num, HEAD_DIM=head_dim, - BLOCK_HEAD=BLOCK_HEAD, BLOCK_TOKENS=BLOCK_TOKENS, num_warps=num_warps, ) @@ -624,23 +577,21 @@ def decoding_fused_rotary_embedding( """ Args: q: query tensor, [total_tokens, head_num, head_dim] - k: key tensor, [total_tokens, head_num, head_dim] - v: value tensor, [total tokens, head_num, head_dim] + k: key tensor, [total_tokens, kv_head_num, head_dim] + v: value tensor, [total tokens, kv_head_num, head_dim] cos: cosine for rotary embedding, [max_position_len, head_dim] sin: sine for rotary embedding, [max_position_len, head_dim] - k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim] - v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim] + k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim] + v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim] kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz] block_tables: Block tables for each sequence. 
[bsz, max_blocks_per_sequence] """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) == v.size(0) - assert q.size(1) == k.size(1) == v.size(1) + assert k.size(1) == v.size(1) assert k_cache.size(-1) == v_cache.size(-1) - if head_dim >= 1024: - num_warps = 32 - elif head_dim >= 512: + if head_dim >= 512: num_warps = 16 elif head_dim >= 256: num_warps = 8 @@ -653,10 +604,12 @@ def decoding_fused_rotary_embedding( k_token_stride = k.stride(0) k_head_stride = k.stride(1) + k_head_num = k.size(1) + kv_group_num = q_head_num // k_head_num cos_token_stride = cos.stride(0) cos_stride = cos.stride(1) - grid = (triton.next_power_of_2(q_head_num), q_total_tokens) + grid = (q_head_num, q_total_tokens) decoding_fused_rotary_embedding_kernel[grid]( q, k, @@ -681,7 +634,7 @@ def decoding_fused_rotary_embedding( block_tables.stride(0), block_tables.stride(1), k_cache.size(-2), - Q_HEAD_NUM=q_head_num, + KV_GROUP_NUM=kv_group_num, HEAD_DIM=head_dim, num_warps=num_warps, ) diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 7125ca386..25413a292 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length): assert not engine.use_spec_dec assert engine.drafter is None and engine.drafter_model is None + max_new_tokens = max_length - dummy_inputs.size(1) assert len(out) == 1 - assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length + assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens # test GLIDE model glide_config = GlideLlamaConfig( @@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length): engine.clear_spec_dec() assert len(out) == 1 - assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length + assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): @@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample): @parameterize("num_layers", [1]) -@parameterize("max_length", [100]) +@parameterize("max_length", [64]) def test_spec_dec(num_layers, max_length): spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length) diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py index a7eb47a76..f641a9102 100644 --- a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py @@ -151,6 +151,16 @@ def test_flash_decoding_attention( numpy_allclose(out_ref, output, rtol=rtol, atol=atol) +try: + from vllm._C import ops as vllm_ops # noqa + + HAS_VLLM = True +except ImportError: + HAS_VLLM = False + print("The subsequent test requires vllm. 
Please refer to https://github.com/vllm-project/vllm") + + +@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm") @pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) @pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) @pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32]) @@ -166,11 +176,6 @@ def test_vllm_flash_decoding_attention( torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() - try: - from vllm._C import ops as vllm_ops - except ImportError: - raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") - NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
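The kernel and cache-manager changes above all key off the same GQA head grouping: kv_group_num = q_head_num // kv_head_num, query head h reads key/value head h // kv_group_num, and only the first query head of each group (h % kv_group_num == 0) rotates and writes K. A short PyTorch reference sketch of that mapping and of the rotary math the kernels implement is given below; tensor layouts follow the docstrings above, and it assumes cos/sin are already gathered per token with only their first head_dim // 2 columns used, mirroring how the kernels index them.

import torch

def kv_head_of(query_head_idx: int, q_head_num: int, kv_head_num: int) -> int:
    # GQA mapping used by the kernels above: each key/value head serves
    # kv_group_num consecutive query heads.
    kv_group_num = q_head_num // kv_head_num
    return query_head_idx // kv_group_num

def gqa_rotary_reference(q, k, cos, sin):
    # q:       [total_tokens, q_head_num, head_dim]
    # k:       [total_tokens, kv_head_num, head_dim]
    # cos/sin: [total_tokens, head_dim // 2], gathered per token position (assumption)
    def rotate(x):
        x0, x1 = x.chunk(2, dim=-1)
        c, s = cos[:, None, :], sin[:, None, :]
        # matches out0 = x0 * cos - x1 * sin and out1 = x0 * sin + x1 * cos in the kernels
        return torch.cat((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1)
    # K is rotated once per key/value head; the Triton kernels achieve this by letting
    # only query heads with cur_head_idx % KV_GROUP_NUM == 0 handle the matching kv head.
    return rotate(q), rotate(k)

For Llama-3-8B (32 query heads, 8 key/value heads) kv_group_num is 4, which is the same value the nopadding_llama change passes to the flash-decoding kernel as self.num_key_value_groups.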