mirror of https://github.com/hpcaitech/ColossalAI
[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)
* fix drafter pastkv and usage of batch bucket

Branch: feat/speculative-decoding
parent a37f82629d
commit 912e24b2aa
@@ -372,18 +372,22 @@ class BatchBucket:
             seq.check_finish()
         self._sequence_lengths[: self.current_batch_size] += 1

-    def revoke_batch_tokens(self, n: int) -> None:
+    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
         """Revoke the last n output tokens of the sequences in the batch

         Args:
-            n (int): The number of output tokens to revoke from each sequence.
+            n_tokens (int): The number of output tokens to revoke from each sequence.
                 It does not count in the context tokens (input tokens).
+            n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
+                For now, speculative decoding only supports batch size 1.
         """
-        if n >= 1:
-            for seq_id, seq in self._sequences_dict.items():
-                assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
-                seq.output_token_id = seq.output_token_id[:-n]
-            self._sequence_lengths -= n
+        if n_tokens >= 1:
+            seqs_iter = iter(self._sequences_dict.items())
+            for _ in range(n_seqs):
+                seq_id, seq = next(seqs_iter)
+                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
+                seq.output_token_id = seq.output_token_id[:-n_tokens]
+                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
         """Clear all the sequences in the batch.
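For illustration, a minimal standalone sketch of the revised revocation semantics (the helper and its dict/tensor arguments below are simplified stand-ins, not the real BatchBucket): only the first n_seqs sequences give back their last n_tokens outputs, and only their rows of the length tensor are decremented, rather than subtracting from every entry as the old code did.

import torch

# Simplified stand-in for BatchBucket.revoke_batch_tokens (illustrative only).
def revoke_batch_tokens(output_ids, seq_lengths, seq_indexes, n_tokens, n_seqs=1):
    """output_ids: dict seq_id -> list of generated token ids
    seq_lengths: 1D tensor of per-sequence lengths (context + output)
    seq_indexes: dict seq_id -> row index into seq_lengths
    """
    if n_tokens >= 1:
        seqs_iter = iter(output_ids.items())
        for _ in range(n_seqs):
            seq_id, tokens = next(seqs_iter)
            assert len(tokens) >= n_tokens, "Revoking len exceeds the current output len"
            output_ids[seq_id] = tokens[:-n_tokens]
            seq_lengths[seq_indexes[seq_id]] -= n_tokens

out = {"s0": [11, 12, 13, 14], "s1": [21, 22]}      # 4 and 2 output tokens
lengths = torch.tensor([9, 7])                      # 5 context tokens each, plus outputs
revoke_batch_tokens(out, lengths, {"s0": 0, "s1": 1}, n_tokens=2, n_seqs=1)
print(out)      # {'s0': [11, 12], 's1': [21, 22]} -- only the first sequence is trimmed
print(lengths)  # tensor([7, 7])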
@@ -269,24 +269,26 @@ class InferenceEngine:
                 device=self.device,
                 dtype=self.dtype,
             )
+        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True

     def disable_spec_dec(self) -> None:
         """Disable using speculative decoding for subsequent generations."""
+        self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
         self.use_spec_dec = False
-        return

     def clear_spec_dec(self) -> None:
         """Clear relatable structures of speculative decoding, if exist."""
+        if self.use_spec_dec:
+            self.disable_spec_dec()
         if self.drafter_model or self.drafter:
             self.drafter_model = None
             self.drafter = None
             torch.cuda.empty_cache()
         self.use_spec_dec = False
-        return

     def steps_spec_dec(self) -> List[Sequence]:
         """
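A minimal self-contained sketch of the state handling added above (the _Bucket, _RequestHandler, and _Engine classes below are hypothetical stand-ins, not the ColossalAI ones): enabling spec-dec now also flips both batch buckets through the request handler, and clear_spec_dec() disables first so no bucket is left in spec-dec mode.

# Hypothetical stand-in classes modelling the spec-dec mode toggles (illustrative only).
class _Bucket:
    use_spec_dec = False

    def set_use_spec_dec(self, n_spec_tokens):
        self.use_spec_dec = True

    def reset_use_spec_dec(self):
        self.use_spec_dec = False


class _RequestHandler:
    def __init__(self):
        self.prefill_bb, self.running_bb = _Bucket(), _Bucket()

    def set_spec_dec_mode(self, n_spec_tokens):
        self.prefill_bb.set_use_spec_dec(n_spec_tokens)
        self.running_bb.set_use_spec_dec(n_spec_tokens)

    def unset_spec_dec_mode(self):
        self.prefill_bb.reset_use_spec_dec()
        self.running_bb.reset_use_spec_dec()


class _Engine:
    def __init__(self, max_n_spec_tokens=5):
        self.request_handler = _RequestHandler()
        self.max_n_spec_tokens = self.n_spec_tokens = max_n_spec_tokens
        self.use_spec_dec, self.drafter = False, None

    def enable_spec_dec(self, drafter, n_spec_tokens):
        self.drafter, self.n_spec_tokens = drafter, n_spec_tokens
        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
        self.use_spec_dec = True

    def disable_spec_dec(self):
        self.request_handler.unset_spec_dec_mode()
        self.n_spec_tokens = self.max_n_spec_tokens
        self.use_spec_dec = False

    def clear_spec_dec(self):
        if self.use_spec_dec:  # new guard: reset the buckets before dropping the drafter
            self.disable_spec_dec()
        self.drafter = None


engine = _Engine()
engine.enable_spec_dec(drafter=object(), n_spec_tokens=3)
assert engine.request_handler.running_bb.use_spec_dec
engine.clear_spec_dec()
assert not engine.request_handler.prefill_bb.use_spec_dec and engine.drafter is None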
@@ -297,7 +299,6 @@ class InferenceEngine:
             List[Sequence]: finished sequences generated by one step.
         """
         batch = self.request_handler.schedule()  # prefill batch
-        batch.set_use_spec_dec(self.n_spec_tokens)  # set batch to use-spec-dec mode

         assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model
@@ -316,19 +317,19 @@ class InferenceEngine:
         already_allocated_kv_len = batch.seq_lengths[0].item()
         input_ids = batch.get_1D_inputs_spec_dec(1)

-        batch.reset_use_spec_dec()  # reset batch use-spec-dec mode
         finished_sequences = self.request_handler.update()

         while True:
             # HACK Retrieve the running batch
             # Using RequestHandler.schedule here will re-allocate same kv cache for the batch
             batch = self.request_handler.running_bb  # running batch
-            batch.set_use_spec_dec(self.n_spec_tokens)
+            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

             # 3. Decoding - Drafter model speculates `n` tokens
             drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
+            drafter_spec_length = drafter_out.speculated_length

             for next_token_id_spec in next_token_ids_spec:
                 self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
@@ -343,22 +344,26 @@ class InferenceEngine:

             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

             # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(self.n_spec_tokens - n_matches)  # revoke drafted tokens
+            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-            input_ids = batch.get_1D_inputs_spec_dec(1)
             # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
+            drafter_past_key_values = Drafter.trim_kv_cache(
+                drafter_past_key_values, drafter_spec_length - n_matches - 1
+            )
+            # prepare inputs for the next round of speculation
+            n = 1 if n_matches < drafter_spec_length else 2
+            input_ids = batch.get_1D_inputs_spec_dec(n)

             self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
             finished_sequences = self.request_handler.update()
             if len(finished_sequences) > 0:
                 break

-        batch.reset_use_spec_dec()

         return finished_sequences

     def generate(
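For reference, a standalone sketch of the verification arithmetic used in the loop above (the verify helper and sample tensors are illustrative, not engine code): it counts how many drafted tokens the main model confirms, then derives from the actual speculated length how many tokens to revoke, how far to trim the drafter's past KV, and how many tokens feed the next speculation round.

import torch

# Illustrative helper mirroring the acceptance/trim arithmetic (not the engine code).
def verify(drafted: torch.Tensor, verified: torch.Tensor, spec_length: int):
    # `verified` holds spec_length + 1 tokens from the main model: one verdict per drafted
    # position plus the bonus token that follows the last accepted one.
    diff_indexes = torch.nonzero(~(verified[:-1] == drafted))
    n_matches = spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
    n_revoke = spec_length - n_matches       # drafted tokens to take back from the batch
    kv_trim = spec_length - n_matches - 1    # drafter past-KV entries to drop
    # Partial acceptance: restart the drafter from the single corrected token.
    # Full acceptance: feed it the last two tokens (accepted draft tail plus bonus token).
    next_input_len = 1 if n_matches < spec_length else 2
    return n_matches, n_revoke, kv_trim, next_input_len

drafted = torch.tensor([5, 8, 2, 9, 4])
verified = torch.tensor([5, 8, 7, 3, 1, 6])      # first mismatch at position 2
print(verify(drafted, verified, spec_length=5))  # (2, 3, 2, 1)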
@@ -181,6 +181,14 @@ class RequestHandler:
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()

+    def set_spec_dec_mode(self, n_spec_tokens: int):
+        self.prefill_bb.set_use_spec_dec(n_spec_tokens)
+        self.running_bb.set_use_spec_dec(n_spec_tokens)
+
+    def unset_spec_dec_mode(self):
+        self.prefill_bb.reset_use_spec_dec()
+        self.running_bb.reset_use_spec_dec()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -208,7 +216,11 @@ class RequestHandler:
                     lst.remove(seq)

         if self.running_list.ready_for_prefill():
-            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
+            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
+            # overwrite the number of sequences to add to 1 if use_spec_dec is enabled
+            # TODO (zhaoyuanheng): support speculative decoding for batch size > 1
+            if self.prefill_bb.use_spec_dec:
+                num_seqs_to_add = 1

             for seq in self.running_list.prefill[:num_seqs_to_add]:
                 seq.mark_running()
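A small sketch of the revised prefill admission rule (a standalone helper with made-up numbers, not the real RequestHandler): the count is now bounded by the prefill bucket's capacity rather than the running bucket's, and forced to 1 while speculative decoding is active.

# Illustrative helper mirroring the admission rule above (not the real scheduler).
def num_prefill_seqs_to_add(prefill_seq_num: int, prefill_bb_available: int, use_spec_dec: bool) -> int:
    num_seqs_to_add = min(prefill_seq_num, prefill_bb_available)
    if use_spec_dec:
        num_seqs_to_add = 1  # speculative decoding currently only supports batch size 1
    return num_seqs_to_add

print(num_prefill_seqs_to_add(4, 8, use_spec_dec=False))  # 4
print(num_prefill_seqs_to_add(4, 8, use_spec_dec=True))   # 1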