From e60d430cf53c9009af4682908d01742147654429 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Sun, 7 Apr 2024 14:53:30 +0800 Subject: [PATCH] [Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557) - resolve conflicts of rebasing feat/speculative-decoding --- colossalai/inference/batch_bucket.py | 1 - colossalai/inference/config.py | 17 ++++++- colossalai/inference/core/engine.py | 46 +++++++++++-------- .../modeling/models/nopadding_llama.py | 12 ----- .../test_ops/triton/test_decoding_attn.py | 1 + .../test_ops/triton/test_kvcache_copy.py | 5 +- 6 files changed, 47 insertions(+), 35 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index d9aa01091..a2a2e74e8 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -97,7 +97,6 @@ class BatchBucket: @property def num_tokens_to_verify(self) -> int: - assert self.use_spec_dec and self._num_tokens_to_verify is not None return self._num_tokens_to_verify def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index b006f9828..9d7c2c0ad 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -46,6 +46,8 @@ class InputMetaData: head_dim (int, optional): Head dimension. Defaults to 32. high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False. dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. + use_spec_dec (bool): Indicate whether to use speculative decoding. + num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. """ block_tables: torch.Tensor = None @@ -59,9 +61,22 @@ class InputMetaData: head_dim: int = 32 high_precision: bool = False dtype: torch.dtype = torch.float32 + use_spec_dec: bool = False + num_tokens_to_verify: int = 0 def __repr__(self) -> str: - return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})" + return ( + f"InputMetaData(block_tables={self.block_tables}, " + f"sequence_lengths={self.sequence_lengths}, " + f"fd_inter_tensor={self.fd_inter_tensor}, " + f"batch_size={self.batch_size}, " + f"is_prompts={self.is_prompts}, " + f"use_cuda_kernel={self.use_cuda_kernel}, " + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_seq_len={self.kv_seq_len}, " + f"use_spec_dec={self.use_spec_dec}, " + f"num_tokens_to_verify={self.num_tokens_to_verify})" + ) @dataclass diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 032a787c3..f6b5a6e79 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -325,24 +325,29 @@ class InferenceEngine: List[Sequence]: finished sequences generated by one step. """ batch = self.request_handler.schedule() # prefill batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." 
- input_ids = batch.get_1D_inputs() # bsz 1 for drafter model + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model # 1. Prefill small model (Drafter) - fill past kv cache for drafter model # NOTE For glide drafter models, we won't actually apply glide during prefill stage - drafter_out = self.drafter.speculate(input_ids, 1, None) + drafter_out = self.drafter.speculate(input_token_ids, 1, None) next_token_ids_spec = drafter_out.next_tokens drafter_past_key_values = drafter_out.past_key_values # 2. Prefill main model (Verifier) - fill past kv cache for main model - logits = self.model(batch, self.k_cahce, self.v_cache) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) next_tokens = self.request_handler.search_tokens(self.generation_config, logits) # append new inputs to the batch, temporarily batch.append_batch_tokens(next_tokens) self.request_handler.allocate_batch_spec_dec(batch, 1) already_allocated_kv_len = batch.seq_lengths[0].item() - input_ids = batch.get_1D_inputs_spec_dec(1) + input_token_ids = batch.get_1D_inputs_spec_dec(1) finished_sequences = self.request_handler.update() @@ -357,13 +362,13 @@ class InferenceEngine: if self.use_glide: glide_input = GlideInput( batch.get_block_table_tensor(), - self.k_cahce[-1], # use kv cahces of the last layer + self.k_cache[-1], # use kv cahces of the last layer self.v_cache[-1], batch.get_sequence_lengths(), ) drafter_out = self.drafter.speculate( - input_ids, + input_token_ids, self.n_spec_tokens, drafter_past_key_values, glide_input=glide_input, @@ -382,7 +387,9 @@ class InferenceEngine: # 4. Decoding - Main model verifies `n` tokens in parallel if drafter_spec_length < batch.num_tokens_to_verify: batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) - logits = self.model(batch, self.k_cahce, self.v_cache) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits) # 5. 
Compare and process the results @@ -402,7 +409,7 @@ class InferenceEngine: # prepare inputs for the next round of speculation n = 1 if n_matches < drafter_spec_length else 2 - input_ids = batch.get_1D_inputs_spec_dec(n) + input_token_ids = batch.get_1D_inputs_spec_dec(n) self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) finished_sequences = self.request_handler.update() @@ -564,18 +571,19 @@ class InferenceEngine: def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: input_ids = batch.get_1D_inputs() - sequence_lengths = batch.get_sequence_lengths() + if batch.is_prompts: - output_tensor = torch.zeros( - (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), - dtype=batch.dtype, - device=batch.device, - ) + n_tokens = sequence_lengths.sum().item() else: - output_tensor = torch.zeros( - (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - ) + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + output_tensor = torch.zeros( + (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + ) # only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph = False @@ -594,6 +602,8 @@ class InferenceEngine: kv_seq_len=sequence_lengths.max().item(), head_dim=batch.head_dim, dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, ) return input_ids, output_tensor, input_meta_data diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 5bffc9d12..1f0008b97 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -109,13 +109,11 @@ def llama_model_forward( # For speculative-decoding Prefill and Verifying Stage if inputmetadata.is_prompts: # output tensor shape is the same as normal Prefill Stage - o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim) rotary_indexes = [torch.arange(0, length) for length in sequence_lengths] else: # the number of tokens to be verified in parallel plus the correct token in the last step n_tokens = inputmetadata.num_tokens_to_verify + 1 assert n_tokens == hidden_states.size(0) - o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim) rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths] rotary_indexes = torch.cat(rotary_indexes, dim=-1) cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) @@ -135,15 +133,6 @@ def llama_model_forward( else: cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) - # TODO (yuanheng-zhao): revise the logic here - # if batch.is_prompts: - # output_tensor = torch.zeros( - # (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - # ) - # else: - # output_tensor = torch.zeros( - # (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - # ) sm_scale = 1.0 / (inputmetadata.head_dim**0.5) norm_output = torch.empty_like(hidden_states) @@ -239,7 +228,6 @@ def llama_decoder_layer_forward( 
sequence_lengths=sequence_lengths, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, - is_prompts=is_prompts, kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py index efb8896e6..d52373128 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py @@ -138,5 +138,6 @@ def test_flash_decoding( assert out_torch.shape == out_triton.shape assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) + if __name__ == "__main__": test_flash_decoding(16, 32, 32, 16, 1, True) diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_ops/triton/test_kvcache_copy.py index 43545df79..c4122a0c7 100644 --- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer/test_ops/triton/test_kvcache_copy.py @@ -2,7 +2,6 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.layers.attention import copy_to_cache from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token @@ -28,8 +27,8 @@ def prepare_data( max_num_blocks_per_seq, same_context_len, max_seq_len, - n, - device, + n=1, + device="cuda", dtype=torch.float16, ): assert max_seq_len > n, "max_seq_len must be greater than n"
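
Note (not part of the patch): the `prepare_input` hunk above now sizes the preallocated `output_tensor` from the new `use_spec_dec` / `num_tokens_to_verify` fields instead of branching only on `is_prompts`. A minimal sketch of that sizing rule is given below; the standalone helper name, its parameter list, and the CPU default device are assumptions for illustration, not the engine's actual API.

    import torch

    def spec_dec_output_rows(
        is_prompts: bool,
        sequence_lengths: torch.Tensor,  # per-sequence prompt lengths, shape [batch_size]
        batch_size: int,
        use_spec_dec: bool,
        num_tokens_to_verify: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: str = "cpu",  # assumed default so the sketch runs without a GPU
    ) -> torch.Tensor:
        if is_prompts:
            # Prefill (including the verifier's prefill in speculative decoding):
            # one row per prompt token, summed over every sequence in the batch.
            n_tokens = int(sequence_lengths.sum().item())
        elif use_spec_dec:
            # Verification step: each sequence submits the drafted tokens plus the
            # bonus token sampled by the main model in the previous round.
            n_tokens = (num_tokens_to_verify + 1) * batch_size
        else:
            # Regular decoding: one token per sequence.
            n_tokens = batch_size
        return torch.zeros((n_tokens, num_heads * head_dim), dtype=dtype, device=device)

For example, a decoding batch of one sequence with five drafted tokens gets a buffer of
6 rows (the drafted tokens plus the bonus token), each of width num_heads * head_dim,
which matches the `n_tokens == input_ids.size(0)` assertion added in `prepare_input`.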