[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)

* fix drafter pastkv and usage of batch bucket
feat/speculative-decoding
Yuanheng Zhao 2024-03-12 17:57:01 +08:00 committed by Yuanheng
parent a37f82629d
commit 912e24b2aa
3 changed files with 40 additions and 19 deletions

@@ -372,18 +372,22 @@ class BatchBucket:
             seq.check_finish()
         self._sequence_lengths[: self.current_batch_size] += 1
 
-    def revoke_batch_tokens(self, n: int) -> None:
+    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
         """Revoke the last n output tokens of the sequences in the batch
 
         Args:
-            n (int): The number of output tokens to revoke from each sequence.
+            n_tokens (int): The number of output tokens to revoke from each sequence.
                 It does not count in the context tokens (input tokens).
+            n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
+                For now, speculative decoding only supports batch size 1.
         """
-        if n >= 1:
-            for seq_id, seq in self._sequences_dict.items():
-                assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
-                seq.output_token_id = seq.output_token_id[:-n]
-            self._sequence_lengths -= n
+        if n_tokens >= 1:
+            seqs_iter = iter(self._sequences_dict.items())
+            for _ in range(n_seqs):
+                seq_id, seq = next(seqs_iter)
+                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
+                seq.output_token_id = seq.output_token_id[:-n_tokens]
+                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
 
     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
         """Clear all the sequences in the batch.

@@ -269,24 +269,26 @@ class InferenceEngine:
             device=self.device,
             dtype=self.dtype,
         )
+        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True
 
     def disable_spec_dec(self) -> None:
         """Disable using speculative decoding for subsequent generations."""
+        self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
         self.use_spec_dec = False
-        return
 
     def clear_spec_dec(self) -> None:
         """Clear relatable structures of speculative decoding, if exist."""
+        if self.use_spec_dec:
+            self.disable_spec_dec()
         if self.drafter_model or self.drafter:
             self.drafter_model = None
             self.drafter = None
             torch.cuda.empty_cache()
         self.use_spec_dec = False
-        return
 
     def steps_spec_dec(self) -> List[Sequence]:
         """
@@ -297,7 +299,6 @@ class InferenceEngine:
             List[Sequence]: finished sequences generated by one step.
         """
         batch = self.request_handler.schedule()  # prefill batch
-        batch.set_use_spec_dec(self.n_spec_tokens)  # set batch to use-spec-dec mode
         assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
 
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model
@@ -316,19 +317,19 @@ class InferenceEngine:
         already_allocated_kv_len = batch.seq_lengths[0].item()
         input_ids = batch.get_1D_inputs_spec_dec(1)
 
-        batch.reset_use_spec_dec()  # reset batch use-spec-dec mode
         finished_sequences = self.request_handler.update()
 
         while True:
             # HACK Retrieve the running batch
             #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
             batch = self.request_handler.running_bb  # running batch
-            batch.set_use_spec_dec(self.n_spec_tokens)
             assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
 
             # 3. Decoding - Drafter model speculates `n` tokens
             drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
+            drafter_spec_length = drafter_out.speculated_length
 
             for next_token_id_spec in next_token_ids_spec:
                 self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
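Recording drafter_out.speculated_length matters because the drafter is not guaranteed to produce all n_spec_tokens tokens; if it stops early, every later step (comparison, revoking, KV trimming) must use the drafted length actually produced. A schematic greedy drafter loop, not ColossalAI's Drafter, illustrating how the speculated length can fall short of the request:

# Schematic drafter loop, NOT ColossalAI's Drafter: shows why the number of tokens
# actually drafted (speculated_length) can be smaller than the number requested.
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyDrafterOutput:
    next_tokens: List[int]
    speculated_length: int
    past_key_values: object = None  # stands in for the real per-layer KV tensors


def toy_speculate(step_fn: Callable[[List[int]], int], input_ids: List[int], n: int, eos_id: int = 2) -> ToyDrafterOutput:
    tokens: List[int] = []
    ids = list(input_ids)
    for _ in range(n):
        nxt = step_fn(ids)   # one greedy decoding step of the small model
        tokens.append(nxt)
        ids.append(nxt)
        if nxt == eos_id:    # the drafter may stop before producing n tokens
            break
    return ToyDrafterOutput(next_tokens=tokens, speculated_length=len(tokens))


# Fake "model": emits 5, 6, then EOS, so only 3 of the 5 requested tokens get drafted.
out = toy_speculate(lambda ids: [5, 6, 2][min(len(ids) - 1, 2)], input_ids=[1], n=5)
print(out.speculated_length)  # 3 -- downstream bookkeeping must use this, not the requested 5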
@@ -343,22 +344,26 @@ class InferenceEngine:
             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
 
             # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(self.n_spec_tokens - n_matches)  # revoke drafted tokens
+            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
 
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-            input_ids = batch.get_1D_inputs_spec_dec(1)
 
             # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
+            drafter_past_key_values = Drafter.trim_kv_cache(
+                drafter_past_key_values, drafter_spec_length - n_matches - 1
+            )
+
+            # prepare inputs for the next round of speculation
+            n = 1 if n_matches < drafter_spec_length else 2
+            input_ids = batch.get_1D_inputs_spec_dec(n)
 
             self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
             finished_sequences = self.request_handler.update()
 
             if len(finished_sequences) > 0:
                 break
 
-        batch.reset_use_spec_dec()
-
         return finished_sequences
 
     def generate(
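To make the bookkeeping in step 5 concrete, here is a worked example with plain Python values; the real code operates on CUDA tensors and the engine's KV cache, so this only mirrors the arithmetic:

# Worked example of the accept/revoke/trim bookkeeping; plain ints and lists, not the engine code.
drafter_spec_length = 4                    # tokens actually drafted this round
next_token_ids_spec = [11, 12, 13, 14]     # drafter proposals
next_tokens = [11, 12, 99, 55, 77]         # verifier outputs (one extra "bonus" token at the end)

# The first mismatch between drafted tokens and verifier outputs decides n_matches.
diff_indexes = [i for i, (a, b) in enumerate(zip(next_tokens[:-1], next_token_ids_spec)) if a != b]
n_matches = drafter_spec_length if not diff_indexes else diff_indexes[0]   # -> 2

n_to_revoke = drafter_spec_length - n_matches       # 2 drafted tokens were wrong, roll them back
accepted_token = next_tokens[n_matches]              # 99: the correction from the main model

# The drafter's cache covers every drafted position except the last token it proposed,
# so trimming (drafter_spec_length - n_matches - 1) entries drops exactly the rejected positions.
n_kv_to_trim = drafter_spec_length - n_matches - 1   # 1

# If the whole draft was accepted, the verifier's bonus token is also kept, so the next
# speculation round must feed 2 fresh tokens to the drafter instead of 1.
n_next_inputs = 1 if n_matches < drafter_spec_length else 2

print(n_matches, n_to_revoke, accepted_token, n_kv_to_trim, n_next_inputs)  # 2 2 99 1 1

The "- 1" in the trim amount and the "else 2" branch are exactly the two fixes this hunk makes to how the drafter's past KV and next inputs are prepared.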

@@ -181,6 +181,14 @@ class RequestHandler:
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()
 
+    def set_spec_dec_mode(self, n_spec_tokens: int):
+        self.prefill_bb.set_use_spec_dec(n_spec_tokens)
+        self.running_bb.set_use_spec_dec(n_spec_tokens)
+
+    def unset_spec_dec_mode(self):
+        self.prefill_bb.reset_use_spec_dec()
+        self.running_bb.reset_use_spec_dec()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -208,7 +216,11 @@ class RequestHandler:
                 lst.remove(seq)
 
         if self.running_list.ready_for_prefill():
-            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
+            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
+            # overwrite the number of sequences to add to 1 if use_spec_dec is enabled
+            # TODO (zhaoyuanheng): support speculative decoding for batch size > 1
+            if self.prefill_bb.use_spec_dec:
+                num_seqs_to_add = 1
 
             for seq in self.running_list.prefill[:num_seqs_to_add]:
                 seq.mark_running()
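The scheduling change has two parts: prefill admission is now bounded by the prefill bucket's capacity rather than the running bucket's, and it is capped at one sequence whenever spec-dec mode is set on the buckets. A toy illustration of that gate (stub classes, not the ColossalAI RequestHandler or BatchBucket):

# Toy stubs, not the ColossalAI classes: shows how the spec-dec flag set on the
# prefill bucket caps prefill admission at a single sequence.
class ToyBucket:
    def __init__(self, available_batch_size: int):
        self.available_batch_size = available_batch_size
        self.use_spec_dec = False
        self.n_spec_tokens = 0

    def set_use_spec_dec(self, n_spec_tokens: int) -> None:
        self.use_spec_dec = True
        self.n_spec_tokens = n_spec_tokens  # kept for signature parity; unused by the toy gate

    def reset_use_spec_dec(self) -> None:
        self.use_spec_dec = False
        self.n_spec_tokens = 0


def num_prefill_seqs_to_add(prefill_bb: ToyBucket, waiting_prefill_seqs: int) -> int:
    n = min(waiting_prefill_seqs, prefill_bb.available_batch_size)
    if prefill_bb.use_spec_dec:  # speculative decoding currently supports batch size 1 only
        n = 1
    return n


prefill_bb, running_bb = ToyBucket(8), ToyBucket(8)
for bb in (prefill_bb, running_bb):  # set_spec_dec_mode propagates the flag to both buckets
    bb.set_use_spec_dec(5)
print(num_prefill_seqs_to_add(prefill_bb, waiting_prefill_seqs=4))  # -> 1 instead of 4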