diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 77cfed4df..e157a9215 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -42,6 +42,9 @@ class BatchBucket: self.device = device or get_current_device() self.dtype = dtype + self._use_spec_dec = False + self._num_tokens_to_verify = None + self._current_batch_size = 0 self._sequences_dict = dict() self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) @@ -88,6 +91,28 @@ class BatchBucket: == torch.nonzero(self._block_tables[:, 0] >= 0).numel() ) + @property + def use_spec_dec(self) -> bool: + return self._use_spec_dec + + @property + def num_tokens_to_verify(self) -> int: + assert self.use_spec_dec and self._num_tokens_to_verify is not None + return self._num_tokens_to_verify + + def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: + """Set the batch bucket to use speculative decoding. + This notifies the model to adjust the lengths of inputs during modeling, + and lets the main model verify tokens in parallel. + """ + self._use_spec_dec = True + self._num_tokens_to_verify = num_tokens_to_verify + + def reset_use_spec_dec(self) -> None: + """Reset the usage of speculative decoding for the batch bucket""" + self._use_spec_dec = False + self._num_tokens_to_verify = None + def _make_compact(self) -> None: # Clean and Compress the batch based on its sequences dict. # Namely, compress sequences to the front and clean the seq lengths and block tables tensors. @@ -347,6 +372,19 @@ class BatchBucket: seq.check_finish() self._sequence_lengths[: self.current_batch_size] += 1 + def revoke_batch_tokens(self, n: int) -> None: + """Revoke the last n output tokens of the sequences in the batch + + Args: + n (int): The number of output tokens to revoke from each sequence. + It does not include the context (input) tokens. + """ + if n >= 1: + for seq_id, seq in self._sequences_dict.items(): + assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence" + seq.output_token_id = seq.output_token_id[:-n] + self._sequence_lengths -= n + def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]: """Clear all the sequences in the batch. @@ -401,6 +439,21 @@ class BatchBucket: return True return False + def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor: + # Used for main model verification in **Decoding Stage** + # `n` is the number of tokens to be verified, + # so the last `n` tokens of each sequence are prepared as the inputs + assert len(self._sequences_dict) > 0, "No sequence in the batch" + assert all( + seq.output_len >= n for seq in self._sequences_dict.values() + ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
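A minimal, self-contained sketch of the bookkeeping that `revoke_batch_tokens` and `get_1D_inputs_spec_dec` perform, using a plain Python list in place of a real `BatchBucket`; the token ids and counts below are toy values, not taken from the diff:

```python
import torch

# Toy state for a single sequence: output tokens appended during drafting.
output_token_id = [11, 12, 13, 14, 15]

# revoke_batch_tokens(n): drop the last n output tokens of each sequence
# after the verifier rejects them.
n_rejected = 2
if n_rejected >= 1:
    output_token_id = output_token_id[:-n_rejected]  # -> [11, 12, 13]

# get_1D_inputs_spec_dec(n): feed the last n output tokens back to the main
# model as a flat 1-D tensor for parallel verification.
n_verify = 2
inputs = torch.tensor(output_token_id[-n_verify:], dtype=torch.long)  # tensor([12, 13])
```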
+ out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.output_token_id[-n:]) + return torch.tensor(out_li, dtype=torch.long, device=self.device) + # For compatibility def get_1D_inputs(self) -> torch.Tensor: assert len(self._sequences_dict) > 0, "No sequence in the batch" @@ -411,8 +464,6 @@ class BatchBucket: seq.output_len == 0 for seq in self._sequences_dict.values() ), "Sequence stage (Prefill/Decoding) must be the same in the batch" out_li = [] - num_tokens = torch.sum(self._sequence_lengths) - out = torch.empty([num_tokens], dtype=torch.long) seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) for seq_id in seq_ids: seq: Sequence = self._sequences_dict[seq_id] @@ -420,6 +471,10 @@ class BatchBucket: return torch.tensor(out_li, dtype=torch.long, device=self.device) else: # Assume decoding stage + if self.use_spec_dec: + # For Speculative Decoding + # the number of tokens to be verified in parallel plus the correct token in the last step + return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1) assert all( seq.output_len > 0 for seq in self._sequences_dict.values() ), "Sequence stage (Prefill/Decoding) must be the same in the batch" diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 01b1ac53e..d0fb06c2e 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -84,6 +84,8 @@ class InferenceConfig: top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None. top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None. min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None. + n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. + glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. tp_size (int): Tensor parallel size, defaults to 1. pp_size (int): Pipeline parallel size, defaults to 1. @@ -118,6 +120,10 @@ class InferenceConfig: top_p: Optional[float] = None min_p: Optional[float] = None + # speculative decoding configs + max_n_spec_tokens: int = 5 + glimpse_large_kv: bool = False + # paged attention configs block_size: int = 16 diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index a2388121b..672d5a959 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.spec import Drafter from colossalai.inference.struct import Sequence from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager @@ -52,19 +53,26 @@ class InferenceEngine: verbose: bool = False, model_policy: Policy = None, ) -> None: - assert inference_config, "Please provide inference_config." 
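For reference, a hypothetical construction of `InferenceConfig` that exercises the new speculative-decoding fields added above; the remaining keyword arguments mirror values that appear elsewhere in this diff, and the exact numbers are assumptions for illustration:

```python
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=128,
    max_output_len=128,
    block_size=16,
    max_n_spec_tokens=5,     # upper bound on tokens speculated per round
    glimpse_large_kv=False,  # keep the drafter's own (small) KV cache
)
```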
- assert tokenizer, "Please provide a tokenizer, either a defined one or str" self.inference_config = inference_config self.model_config = model.config + self.model = model self.device = torch.device("cuda") self.dtype = inference_config.dtype self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token - self.generation_config = inference_config.to_generation_config(self.model_config) self.high_precision = inference_config.high_precision - model = model.eval() - model = model.cuda() - model.to(self.dtype) + self._verify_args() + + self.generation_config = inference_config.to_generation_config(self.model_config) + model.eval() + model = model.to(self.dtype) + model = model.to(self.device) + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = False + self.drafter_model = None + self.drafter = None + self.n_spec_tokens = self.inference_config.max_n_spec_tokens if model_policy is None: if self.inference_config.pad_input: @@ -174,21 +182,18 @@ class InferenceEngine: if self.verbose: self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") - def _verify_config(self) -> None: - """ - Verify the input config - """ + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") if not isinstance(self.model, nn.Module): raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") - if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance( - self.tokenizer, PreTrainedTokenizer - ): + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): raise TypeError( f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" ) - assert ( - self.model.__class__.__name__ in _supported_models - ), f"Model {self.model.__class__.__name__} is not supported." + if self.model.__class__.__name__ not in _supported_models: + raise ValueError(f"Model {self.model.__class__.__name__} is not supported.") def _shardformer( self, @@ -224,6 +229,138 @@ class InferenceEngine: shard_model, _ = shardformer.optimize(model, model_policy) return shard_model + def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None: + """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. + + Args: + drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. + If provided, the previous drafter and drafter model, if exist, will be overwritten. + n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. + If not provided, `max_n_spec_tokens` in InferenceConfig will be used. + + ```python + ... + engine = InferenceEngine(model, tokenizer, inference_config) + + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + engine.generate(...) # Speculative Decoding + + engine.disable_spec_dec() + engine.generate(...) # Normal generation + + engine.enable_spec_dec() + engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens + engine.clear_spec_dec() + ``` + """ + if drafter_model is None and self.drafter is None: + raise ValueError("Drafter not initialized. 
Please provide a Drafter Model") + if n_spec_tokens is not None: + assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens + self.n_spec_tokens = n_spec_tokens + if drafter_model is not None: + assert isinstance(drafter_model, nn.Module) + # overwrite the drafter, if exists + self.clear_spec_dec() + self.drafter_model = drafter_model + self.drafter = Drafter( + self.drafter_model, + self.tokenizer, + device=self.device, + dtype=self.dtype, + ) + # using speculative decoding for subsequent generations + self.use_spec_dec = True + + def disable_spec_dec(self) -> None: + """Disable using speculative decoding for subsequent generations.""" + # set back to the maximum number of tokens to speculate + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.use_spec_dec = False + return + + def clear_spec_dec(self) -> None: + """Clear relatable structures of speculative decoding, if exist.""" + if self.drafter_model or self.drafter: + self.drafter_model = None + self.drafter = None + torch.cuda.empty_cache() + self.use_spec_dec = False + return + + def steps_spec_dec(self) -> List[Sequence]: + """ + Run Speculative Decoding steps. This is like retrieving a single batch and launch inference + with many steps of speculating by a drafter model as well as verifying by a main model. + + Returns: + List[Sequence]: finished sequences generated by one step. + """ + batch = self.request_handler.schedule() # prefill batch + batch.set_use_spec_dec(self.n_spec_tokens) # set batch to use-spec-dec mode + + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + input_ids = batch.get_1D_inputs() # bsz 1 for drafter model + + # 1. Prefill small model (Drafter) - fill past kv cache for drafter model + drafter_out = self.drafter.speculate(input_ids, 1, None) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + # 2. Prefill main model (Verifier) - fill past kv cache for main model + logits = self.model(batch, self.k_cahce, self.v_cache) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + # append new inputs to the batch, temporarily + batch.append_batch_tokens(next_tokens) + self.request_handler.allocate_batch_spec_dec(batch, 1) + already_allocated_kv_len = batch.seq_lengths[0].item() + input_ids = batch.get_1D_inputs_spec_dec(1) + + batch.reset_use_spec_dec() # reset batch use-spec-dec mode + finished_sequences = self.request_handler.update() + + while True: + # HACK Retrieve the running batch + # Using RequestHandler.schedule here will re-allocate same kv cache for the batch + batch = self.request_handler.running_bb # running batch + batch.set_use_spec_dec(self.n_spec_tokens) + + # 3. Decoding - Drafter model speculates `n` tokens + drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + for next_token_id_spec in next_token_ids_spec: + self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) + cur_length = batch.seq_lengths[0].item() + if already_allocated_kv_len < cur_length: + self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) + already_allocated_kv_len = cur_length + + # 4. 
Decoding - Main model verifies `n` tokens in parallel + logits = self.model(batch, self.k_cahce, self.v_cache) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + + # 5. Compare and process the results + diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) + n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() + # revoke appended tokens for each Sequence in the current batch + batch.revoke_batch_tokens(self.n_spec_tokens - n_matches) # revoke drafted tokens + # append the last correct token generated by the main model + self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) + input_ids = batch.get_1D_inputs_spec_dec(1) + # trim past key values of the drafter model + drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1) + + self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) + finished_sequences = self.request_handler.update() + if len(finished_sequences) > 0: + break + + batch.reset_use_spec_dec() + + return finished_sequences + def generate( self, prompts: List[str] = None, @@ -246,7 +383,6 @@ class InferenceEngine: List[str]: Inference result returned by one generation. """ with torch.inference_mode(): - self.generation_config = generation_config if prompts is not None or prompts_token_ids is not None: self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids) @@ -257,8 +393,13 @@ class InferenceEngine: if generation_config is not None: self.generation_config = generation_config - while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.step() + if self.use_spec_dec: + assert self.drafter is not None, "Drafter Model is not initialized." 
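A standalone sketch of the accept/reject rule used in `steps_spec_dec` above, with toy tensors standing in for the drafter and verifier outputs; the variable names follow the diff, the values are made up:

```python
import torch

n_spec_tokens = 5
next_token_ids_spec = torch.tensor([5, 9, 2, 7, 4])  # drafter: n speculated tokens
next_tokens = torch.tensor([5, 9, 3, 8, 1, 6])       # verifier: n + 1 sampled tokens

# First position where the verifier disagrees with the drafter.
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

# n_matches drafted tokens are kept; the remaining n_spec_tokens - n_matches
# are revoked, and the verifier's token at position n_matches is appended
# as the next accepted token (here n_matches == 2, accepted token == 3).
accepted = next_tokens[n_matches]
```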
+ while self.request_handler.check_unfinished_seqs(): + output_seqs_list += self.steps_spec_dec() + else: + while self.request_handler.check_unfinished_seqs(): + output_seqs_list += self.step() output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) @@ -428,7 +569,8 @@ class InferenceEngine: logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) if self.inference_config.pad_input: logits = logits[:, -1, :] - self.request_handler.search_tokens(self.generation_config, logits) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 9969c6786..6c1a232e2 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -134,8 +134,12 @@ class RequestHandler: if fd_inter_tensor._tensors_initialized: fd_inter_tensor._reset() + # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq + max_n_tokens = self.max_batch_size + max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 + fd_inter_tensor.initialize( - max_batch_size=self.max_batch_size, + max_batch_size=max_n_tokens, num_attn_heads=model_config.num_attention_heads, kv_max_split_num=kv_max_split_num, head_dim=head_dim, @@ -230,6 +234,13 @@ class RequestHandler: return self.running_bb + def allocate_batch_spec_dec(self, batch: BatchBucket, n: int): + assert batch.use_spec_dec + if n > 0: + self.cache_manager.allocate_n_tokens_from_block_tables( + batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n + ) + def add_sequence(self, req: Sequence): """ Add the request to waiting list. 
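A small arithmetic sketch of the intermediate-buffer sizing above: during verification every sequence contributes the speculated tokens plus the token from the last step, so the flash-decoding workspace is sized per token rather than per sequence; the numbers below are illustrative only:

```python
max_batch_size = 8
max_n_spec_tokens = 5

# Normal decoding needs one query row per sequence; spec-dec verification
# needs up to (max_n_spec_tokens + 1) rows per sequence.
max_n_tokens = max_batch_size * (max_n_spec_tokens + 1)  # 48 rows instead of 8
```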
@@ -282,13 +293,21 @@ class RequestHandler: return sample_tokens - def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig): + def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig): if ( - sequence.output_token_id[-1] == generation_config.eos_id - or sequence.output_len >= generation_config.max_output_len + sequence.output_token_id[-1] == generation_config.eos_token_id + or sequence.output_len >= generation_config.max_length ): sequence.mark_finished() + def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig): + for seq in batch.seqs_li: + if ( + seq.output_token_id[-1] == generation_config.eos_token_id + or seq.output_len >= generation_config.max_length + ): + seq.mark_finished() + def check_unfinished_seqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() @@ -309,9 +328,20 @@ class RequestHandler: # sample the next tokens sample_tokens = self._sample(probs, logprobs, generation_config) + return sample_tokens + + def append_next_tokens(self, sample_tokens: torch.Tensor): + assert sample_tokens.dim() == 1 + n_elements = sample_tokens.size(0) if not self.prefill_bb.is_empty: + assert ( + self.prefill_bb.current_batch_size == n_elements + ), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}" self.prefill_bb.append_batch_tokens(sample_tokens) else: + assert ( + self.running_bb.current_batch_size == n_elements + ), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}" self.running_bb.append_batch_tokens(sample_tokens) def update(self): diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 7d435d59c..2b6445d1c 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -349,6 +349,26 @@ class KVCacheManager: return seqs_to_recycle + def allocate_n_tokens_from_block_tables( + self, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + bsz: int, + n: int, + ) -> List[int]: + """Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage.""" + assert block_tables.dim() == 2 + assert context_lens.dim() == 1 + + bsz = block_tables.size(0) if bsz is None else bsz + assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1 + + seqs_to_recycle = [] + for i in range(n): + seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz) + + return seqs_to_recycle + def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int: """Allocate space asked on a single block in the block table, specified by the provided position id, and updates the provided block table with the allocated block. @@ -420,9 +440,7 @@ class KVCacheManager: Returns: The remaining space required to be allocated (in other blocks). """ - assert ( - block.available_space > 0 - ), "Tried to allocate some space but found no available space left in chosen block." + assert block.available_space > 0, f"Found no available space left in the chosen block {block}." 
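A sketch of what `allocate_n_tokens_from_block_tables` above does per step: it replays the existing single-token allocation with the intermediate context lengths, so block boundaries are crossed exactly as in normal decoding. Toy values only; the real allocation happens inside `allocate_tokens_from_block_tables`:

```python
import torch

context_lens = torch.tensor([19])  # length of the single sequence, new tokens included
n = 3                              # tokens to allocate for verification

for i in range(n):
    step_lens = context_lens - n + i + 1  # tensor([17]), then ([18]), then ([19])
    # in the manager, step_lens is passed to
    # allocate_tokens_from_block_tables(block_tables, step_lens, bsz)
```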
space_to_allocate = min(block.available_space, space_asked) block.allocate(space_to_allocate) return space_asked - space_to_allocate diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index c5b61385f..5bffc9d12 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( context_attention_unpadded, + copy_k_to_blocked_cache, decoding_fused_rotary_embedding, flash_decoding_attention, get_xine_cache, @@ -84,9 +85,9 @@ def llama_model_forward( """This function will replace the forward function of LlamaModel. Args: - batch (BatchInfo): It stores the necessary input information for this inference. - k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache. - v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache. + batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ block_tables = inputmetadata.block_tables @@ -101,7 +102,25 @@ def llama_model_forward( use_cuda_kernel = False hidden_states = self.embed_tokens(input_tokens_ids) - if use_cuda_kernel: + cu_seqlens = None + + # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now + if inputmetadata.use_spec_dec: + # For speculative-decoding Prefill and Verifying Stage + if inputmetadata.is_prompts: + # output tensor shape is the same as normal Prefill Stage + o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim) + rotary_indexes = [torch.arange(0, length) for length in sequence_lengths] + else: + # the number of tokens to be verified in parallel plus the correct token in the last step + n_tokens = inputmetadata.num_tokens_to_verify + 1 + assert n_tokens == hidden_states.size(0) + o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim) + rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths] + rotary_indexes = torch.cat(rotary_indexes, dim=-1) + cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) + + elif use_cuda_kernel: if inputmetadata != torch.float32 and use_flash_attn2: cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) @@ -113,14 +132,22 @@ def llama_model_forward( self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts ) cos_sin = (cos, sin) - else: - cu_seqlens = None cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) + # TODO (yuanheng-zhao): revise the logic here + # if batch.is_prompts: + # output_tensor = torch.zeros( + # (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + # ) + # else: + # output_tensor = torch.zeros( + # (batch_size, batch.num_heads * batch.head_dim), 
dtype=batch.dtype, device=batch.device + # ) sm_scale = 1.0 / (inputmetadata.head_dim**0.5) norm_output = torch.empty_like(hidden_states) + tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None residual = None for layer_id, decoder_layer in enumerate(self.layers): @@ -131,6 +158,8 @@ def llama_model_forward( k_cache=k_caches[layer_id], v_cache=v_caches[layer_id], is_prompts=inputmetadata.is_prompts, + is_verifier=inputmetadata.use_spec_dec, + tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, cos_sin=cos_sin, fd_inter_tensor=inputmetadata.fd_inter_tensor, @@ -144,9 +173,9 @@ def llama_model_forward( ) if inputmetadata.is_prompts: - last_token_indexs = sequence_lengths.cumsum(dim=-1) - hidden_states = hidden_states[last_token_indexs - 1].contiguous() - residual = residual[last_token_indexs - 1].contiguous() + seq_len_cumsum = sequence_lengths.cumsum(dim=0) + hidden_states = hidden_states[seq_len_cumsum - 1].contiguous() + residual = residual[seq_len_cumsum - 1].contiguous() norm_output = torch.empty_like(hidden_states) hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) @@ -164,6 +193,8 @@ def llama_decoder_layer_forward( cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, + is_verifier: bool = False, + tokens_to_verify: int = None, kv_seq_len: int = 0, output_tensor: torch.Tensor = None, norm_output: torch.Tensor = None, @@ -202,6 +233,9 @@ def llama_decoder_layer_forward( block_tables=block_tables, k_cache=k_cache, v_cache=v_cache, + is_prompts=is_prompts, + is_verifier=is_verifier, + tokens_to_verify=tokens_to_verify, sequence_lengths=sequence_lengths, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, @@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention): cos_sin: Tuple[torch.Tensor], fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, + is_verifier: bool = False, + tokens_to_verify: int = None, kv_seq_len: int = 0, output_tensor: torch.Tensor = None, sm_scale: int = None, @@ -355,7 +391,7 @@ class NopadLlamaAttention(LlamaAttention): block_size = k_cache.size(-2) if is_prompts: - if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: # flash attn 2 currently only supports FP16/BF16. 
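A sketch of the rotary position indices built in the verification branch of `llama_model_forward` above: for a sequence whose (already updated) length is L and n query tokens, the positions are L - n, ..., L - 1, gathered from the cached cos/sin tables. The batch size of 1 and the lengths below are toy assumptions:

```python
import torch

sequence_lengths = torch.tensor([12])  # bsz 1, as currently supported
n_tokens = 6                           # num_tokens_to_verify + 1

rotary_indexes = [
    (length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths
]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)  # tensor([ 6,  7,  8,  9, 10, 11])
# in the model: cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
```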
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) inference_ops.context_kv_cache_memcpy( @@ -405,17 +441,27 @@ class NopadLlamaAttention(LlamaAttention): high_precision, ) else: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) + q_len = tokens_to_verify + 1 if is_verifier else 1 + if is_verifier: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + copy_k_to_blocked_cache( + key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + copy_k_to_blocked_cache( + value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + else: + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, + ) attn_output = flash_decoding_attention( q=query_states, k_cache=k_cache, @@ -428,8 +474,10 @@ class NopadLlamaAttention(LlamaAttention): mid_output=fd_inter_tensor.mid_output, mid_output_lse=fd_inter_tensor.mid_output_lse, sm_scale=sm_scale, + q_len=q_len, ) + attn_output = attn_output.view(-1, self.hidden_size) attn_output = torch.mm(attn_output, self.o_proj_weight) return attn_output diff --git a/colossalai/inference/spec/drafter.py b/colossalai/inference/spec/drafter.py index 156b6d7f0..b915ea2d9 100644 --- a/colossalai/inference/spec/drafter.py +++ b/colossalai/inference/spec/drafter.py @@ -15,93 +15,75 @@ class Drafter: Args: model (nn.Module): The drafter model. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model. - max_spec_num (int): The maximum number of tokens to speculate. device (torch.device): The device for the drafter model. """ def __init__( - self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None + self, + model: nn.Module, + tokenizer: PreTrainedTokenizer, + device: torch.device = None, + dtype: torch.dtype = torch.float16, ): - self._drafter_model = model self._tokenizer = tokenizer - self.max_spec_num = max_spec_num - self.do_sample = False - self.sample_fn = None self._device = device or get_current_device() - self._past_key_values = None - - @property - def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]: - return self._past_key_values - - # Debug usage for now - @property - def past_key_values_shape(self): - if self._past_key_values is None: - return [] - return self._past_key_values[0][0].shape + self._dtype = dtype + self._drafter_model = model.to(self._device) + self._drafter_model = model.to(self._dtype) + self._drafter_model.eval() def get_model(self) -> nn.Module: return self._drafter_model - def reset_sample_method(self, sample_fn: callable) -> None: - self.do_sample = True - self.sample_fn = sample_fn + @staticmethod + def trim_kv_cache( + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int + ) -> Tuple[Tuple[torch.FloatTensor]]: + """Trim the last `invalid_token_num` kv caches. - def clear_sample_method(self) -> None: - self.do_sample = False - self.sample_fn = None + past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape + num_layers x 2 x (bsz x num_heads x seq_len x head_dim) + invalid_token_num (int): The number of invalid tokens to trim. 
+ """ + if past_key_values is None or invalid_token_num < 1: + return past_key_values - def reset_max_spec_num(self, n: int) -> None: - assert isinstance(n, int) and n > 1 - self.max_spec_num = n - - def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None: - self._past_key_values = past_key_values - - def trim_kv_cache(self, invalid_token_num) -> Tuple[Tuple[torch.FloatTensor]]: - # Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim) - # Trim the last `invalid_token_num` kv caches - # The verifier (main model) might reject `invalid_token_num` tokens, - # and so that we have to trim the invalid tokens for the kv cache of the drafter model. - assert self._past_key_values is not None trimmed_past_key_values = [] - for layer_idx in range(len(self._past_key_values)): - past_key_value = self._past_key_values[layer_idx] + for layer_idx in range(len(past_key_values)): + past_key_value = past_key_values[layer_idx] trimmed_past_key_values.append( ( past_key_value[0][:, :, :-invalid_token_num, :], past_key_value[1][:, :, :-invalid_token_num, :], ) ) - self._past_key_values = tuple(trimmed_past_key_values) - return self._past_key_values + past_key_values = tuple(trimmed_past_key_values) + return past_key_values @torch.inference_mode() def speculate( - self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None + self, + input_ids: torch.Tensor, + n_spec_tokens: int, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, ) -> DrafterOutput: - """Generate n tokens using the drafter model. + """Generate n_spec_tokens tokens using the drafter model. Args: input_ids (torch.Tensor): Input token ids. - n (int): Number of tokens to speculate. + n_spec_tokens (int): Number of tokens to speculate. past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence. """ + assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate" - assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate" - - # FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0) + # For compatibility with transformers of versions before 4.38.0 if input_ids.dim() == 1: input_ids = input_ids.unsqueeze(0) - if past_key_values is None: - past_key_values = self._past_key_values - logits = [] token_ids = [] - for _ in range(n): + for _ in range(n_spec_tokens): outputs = self._drafter_model( input_ids, return_dict=True, @@ -110,17 +92,10 @@ class Drafter: ) next_token_logits = outputs.logits[:, -1, :] - # Skip logits_processor for drafter model - - # Sample - if self.do_sample: - if self.sample_fn is not None: - probs = self.sample_fn(next_token_logits) - else: - probs = nn.functional.softmax(next_token_logits, dim=-1) - next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_token_ids = torch.argmax(next_token_logits, dim=-1) + # NOTE Only use greedy search for speculating. + # As the drafter model usually has only a few layers with few parameters, + # introducing sampling will make the speculation unstable and lead to worse performance. 
+ next_token_ids = torch.argmax(next_token_logits, dim=-1) logits.append(next_token_logits) token_ids.append(next_token_ids) @@ -133,8 +108,6 @@ class Drafter: speculated_length = len(token_ids) # TODO For now, only support bsz 1 logits = torch.concat(logits, dim=0) token_ids = torch.concat(token_ids, dim=-1) - # update past_key_values - self._past_key_values = past_key_values out = DrafterOutput( speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index e1ccffe53..dcbad7bc8 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel( cur_seq_idx = cur_token_idx // q_len if cur_seq_idx >= batch_size: return + cur_token_off = (cur_token_idx % q_len) - q_len + 1 cur_head_idx = tl.program_id(1) block_start_kv = tl.program_id(2) # for splitting k/v @@ -52,7 +53,8 @@ def _flash_decoding_fwd_kernel( # and then support calculating multiple kv cache blocks on an instance tl.static_assert(BLOCK_KV == BLOCK_SIZE) # get the current (kv) sequence length - cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off if block_start_kv * BLOCK_KV >= cur_kv_seq_len: return @@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel( return cur_head_idx = tl.program_id(1) - cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_token_off = (cur_token_idx % q_len) - q_len + 1 + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off offsets_dmodel = tl.arange(0, HEAD_DIM) # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have diff --git a/tests/test_infer/test_drafter.py b/tests/test_infer/test_drafter.py index d1728ecfc..e0d63a294 100644 --- a/tests/test_infer/test_drafter.py +++ b/tests/test_infer/test_drafter.py @@ -2,10 +2,15 @@ import pytest import torch from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM +import colossalai +from colossalai.inference.config import GenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine from colossalai.inference.spec.drafter import Drafter +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device NUM_LAYERS = 2 +MAX_LEN = 100 @pytest.mark.parametrize("spec_num", [5]) @@ -14,13 +19,13 @@ def test_drafter(spec_num: int): device = get_current_device() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS) - toy_config.pad_token_id = toy_config.eos_token_id + toy_config.pad_token_id = tokenizer.eos_token_id drafter_model = LlamaForCausalLM(toy_config) drafter_model = drafter_model.eval().cuda() - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - drafter = Drafter(drafter_model, tokenizer, spec_num, device=device) + drafter = Drafter(drafter_model, tokenizer, device=device) input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device) out = drafter.speculate(input_ids, spec_num) @@ -29,13 +34,75 @@ def test_drafter(spec_num: int): assert out.speculated_length == spec_num assert out.next_tokens.shape == 
(spec_num,) assert out.logits.shape == (spec_num, len(tokenizer)) - assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length + assert out.past_key_values[0][0].size(2) == past_kv_length - reject_num = 3 - assert reject_num <= spec_num - drafter.trim_kv_cache(reject_num) - assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num + reject_num = max(0, spec_num - 1) + trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num) + assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num + + +def check_sd(): + torch.manual_seed(123) + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + # Dummy configs for testing + toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS) + toy_config.pad_token_id = tokenizer.eos_token_id + drafter_model = LlamaForCausalLM(toy_config) + drafter_model = drafter_model.eval().cuda() + large_config = LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=8, + num_key_value_heads=32, + max_position_embeddings=2048, + ) + large_config.pad_token_id = tokenizer.eos_token_id + main_model = LlamaForCausalLM(large_config) + + inference_config = InferenceConfig( + dtype="fp16", + micro_batch_size=1, + max_batch_size=1, + max_input_len=128, + max_output_len=128, + prefill_ratio=1.2, + block_size=16, + ) + engine = InferenceEngine(main_model, tokenizer, inference_config) + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + + dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda") + generation_config = GenerationConfig( + pad_token_id=tokenizer.eos_token_id, + max_length=MAX_LEN, + eos_token_id=tokenizer.eos_token_id, + ) + out, out_token_ids = engine.generate( + prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True + ) + engine.disable_spec_dec() + engine.clear_spec_dec() + + assert not engine.use_spec_dec + assert engine.drafter is None and engine.drafter_model is None + + assert len(out) == 1 + assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_sd() + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_spec_dec(): + spawn(run_dist, nprocs=1) if __name__ == "__main__": test_drafter(spec_num=5) + test_spec_dec() diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_ops/triton/kernel_utils.py index f1ae45477..7ae5a833b 100644 --- a/tests/test_infer/test_ops/triton/kernel_utils.py +++ b/tests/test_infer/test_ops/triton/kernel_utils.py @@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) -def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"): +def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"): + assert q_len <= kv_len + + causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1) + padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device) for i in range(bsz): cur_seq_len = kv_lengths[i].item() assert cur_seq_len <= kv_len padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf") + + padding_mask[:, :, -q_len:, -q_len:] += 
causal_mask + return padding_mask @@ -56,11 +63,13 @@ def torch_attn_ref( attn_scores = qk / (head_dim**0.5) assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores" - # for left-side padding - if attention_mask.size() != (bsz, 1, q_len, kv_len): - raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}") + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}" + ) + attn_scores = attn_scores + attention_mask - attn_scores = attn_scores + attention_mask attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) out = torch.matmul(attn_weights, v) if out.size() != (bsz, num_heads, q_len, head_dim): diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py index 77354e1bb..efb8896e6 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py @@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, + create_attention_mask, generate_caches_and_block_tables_v2, - prepare_padding_mask, torch_attn_ref, ) @@ -91,9 +91,9 @@ def test_flash_decoding( k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b) v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b) - torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device) + attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device) out_torch = torch_attn_ref( - q, k_torch, v_torch, torch_padding_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM + q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM ) k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( @@ -138,6 +138,5 @@ def test_flash_decoding( assert out_torch.shape == out_triton.shape assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) - if __name__ == "__main__": test_flash_decoding(16, 32, 32, 16, 1, True)
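To make the new `create_attention_mask` helper concrete, here is a small sketch of the mask it builds for a verification-style query with q_len > 1: left padding is masked per sequence and a causal triangle is added over the trailing q_len x q_len block. The sizes below are toy values:

```python
import torch

bsz, q_len, kv_len = 1, 3, 6
kv_lengths = torch.tensor([5])  # sequence 0 has 5 valid kv positions

causal_mask = torch.full((q_len, q_len), float("-inf")).triu(diagonal=1)
mask = torch.zeros((bsz, 1, q_len, kv_len))
for i in range(bsz):
    mask[i, :, :, : kv_len - kv_lengths[i].item()] = float("-inf")  # mask left padding
mask[:, :, -q_len:, -q_len:] += causal_mask
# Each of the last q_len query rows can now attend to every valid position
# up to and including its own, matching what torch_attn_ref expects.
```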