|
|
|
@ -325,24 +325,29 @@ class InferenceEngine:
|
|
|
|
|
List[Sequence]: finished sequences generated by one step. |
|
|
|
|
""" |
|
|
|
|
batch = self.request_handler.schedule() # prefill batch |
|
|
|
|
|
|
|
|
|
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." |
|
|
|
|
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model |
|
|
|
|
|
|
|
|
|
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) |
|
|
|
|
|
|
|
|
|
if input_meta_data.use_cuda_graph: |
|
|
|
|
model_executable = self.graph_runners[input_meta_data.batch_size] |
|
|
|
|
else: |
|
|
|
|
model_executable = self.model |
|
|
|
|
|
|
|
|
|
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model |
|
|
|
|
# NOTE For glide drafter models, we won't actually apply glide during prefill stage |
|
|
|
|
drafter_out = self.drafter.speculate(input_ids, 1, None) |
|
|
|
|
drafter_out = self.drafter.speculate(input_token_ids, 1, None) |
|
|
|
|
next_token_ids_spec = drafter_out.next_tokens |
|
|
|
|
drafter_past_key_values = drafter_out.past_key_values |
|
|
|
|
|
|
|
|
|
# 2. Prefill main model (Verifier) - fill past kv cache for main model |
|
|
|
|
logits = self.model(batch, self.k_cahce, self.v_cache) |
|
|
|
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |
|
|
|
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits) |
|
|
|
|
# append new inputs to the batch, temporarily |
|
|
|
|
batch.append_batch_tokens(next_tokens) |
|
|
|
|
self.request_handler.allocate_batch_spec_dec(batch, 1) |
|
|
|
|
already_allocated_kv_len = batch.seq_lengths[0].item() |
|
|
|
|
input_ids = batch.get_1D_inputs_spec_dec(1) |
|
|
|
|
input_token_ids = batch.get_1D_inputs_spec_dec(1) |
|
|
|
|
|
|
|
|
|
finished_sequences = self.request_handler.update() |
|
|
|
|
|
|
|
|
@ -357,13 +362,13 @@ class InferenceEngine:
|
|
|
|
|
if self.use_glide: |
|
|
|
|
glide_input = GlideInput( |
|
|
|
|
batch.get_block_table_tensor(), |
|
|
|
|
self.k_cahce[-1], # use kv cahces of the last layer |
|
|
|
|
self.k_cache[-1], # use kv cahces of the last layer |
|
|
|
|
self.v_cache[-1], |
|
|
|
|
batch.get_sequence_lengths(), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
drafter_out = self.drafter.speculate( |
|
|
|
|
input_ids, |
|
|
|
|
input_token_ids, |
|
|
|
|
self.n_spec_tokens, |
|
|
|
|
drafter_past_key_values, |
|
|
|
|
glide_input=glide_input, |
|
|
|
@ -382,7 +387,9 @@ class InferenceEngine:
|
|
|
|
|
# 4. Decoding - Main model verifies `n` tokens in parallel |
|
|
|
|
if drafter_spec_length < batch.num_tokens_to_verify: |
|
|
|
|
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) |
|
|
|
|
logits = self.model(batch, self.k_cahce, self.v_cache) |
|
|
|
|
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) |
|
|
|
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |
|
|
|
|
|
|
|
|
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits) |
|
|
|
|
|
|
|
|
|
# 5. Compare and process the results |
|
|
|
@ -402,7 +409,7 @@ class InferenceEngine:
|
|
|
|
|
|
|
|
|
|
# prepare inputs for the next round of speculation |
|
|
|
|
n = 1 if n_matches < drafter_spec_length else 2 |
|
|
|
|
input_ids = batch.get_1D_inputs_spec_dec(n) |
|
|
|
|
input_token_ids = batch.get_1D_inputs_spec_dec(n) |
|
|
|
|
|
|
|
|
|
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) |
|
|
|
|
finished_sequences = self.request_handler.update() |
|
|
|
@ -564,18 +571,19 @@ class InferenceEngine:
|
|
|
|
|
|
|
|
|
|
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: |
|
|
|
|
input_ids = batch.get_1D_inputs() |
|
|
|
|
|
|
|
|
|
sequence_lengths = batch.get_sequence_lengths() |
|
|
|
|
|
|
|
|
|
if batch.is_prompts: |
|
|
|
|
output_tensor = torch.zeros( |
|
|
|
|
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), |
|
|
|
|
dtype=batch.dtype, |
|
|
|
|
device=batch.device, |
|
|
|
|
) |
|
|
|
|
n_tokens = sequence_lengths.sum().item() |
|
|
|
|
else: |
|
|
|
|
output_tensor = torch.zeros( |
|
|
|
|
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
) |
|
|
|
|
n_tokens = batch.current_batch_size |
|
|
|
|
if batch.use_spec_dec: |
|
|
|
|
n_tokens = batch.num_tokens_to_verify + 1 |
|
|
|
|
assert n_tokens == input_ids.size(0) |
|
|
|
|
n_tokens = n_tokens * batch.current_batch_size |
|
|
|
|
output_tensor = torch.zeros( |
|
|
|
|
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference |
|
|
|
|
use_cuda_graph = False |
|
|
|
@ -594,6 +602,8 @@ class InferenceEngine:
|
|
|
|
|
kv_seq_len=sequence_lengths.max().item(), |
|
|
|
|
head_dim=batch.head_dim, |
|
|
|
|
dtype=batch.dtype, |
|
|
|
|
use_spec_dec=batch.use_spec_dec, |
|
|
|
|
num_tokens_to_verify=batch.num_tokens_to_verify, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return input_ids, output_tensor, input_meta_data |
|
|
|
|