mirror of https://github.com/hpcaitech/ColossalAI
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification * add spec-dec * add test for spec-dec * revise drafter init * remove drafter sampling * retire past kv in drafter * (trivial) rename attrs * (trivial) rename arg * revise how we enable/disable spec-decfeat/speculative-decoding
parent
5a9b05f7b2
commit
a37f82629d
|
@ -42,6 +42,9 @@ class BatchBucket:
|
||||||
self.device = device or get_current_device()
|
self.device = device or get_current_device()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self._use_spec_dec = False
|
||||||
|
self._num_tokens_to_verify = None
|
||||||
|
|
||||||
self._current_batch_size = 0
|
self._current_batch_size = 0
|
||||||
self._sequences_dict = dict()
|
self._sequences_dict = dict()
|
||||||
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
|
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
|
||||||
|
@ -88,6 +91,28 @@ class BatchBucket:
|
||||||
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
|
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_spec_dec(self) -> bool:
|
||||||
|
return self._use_spec_dec
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_tokens_to_verify(self) -> int:
|
||||||
|
assert self.use_spec_dec and self._num_tokens_to_verify is not None
|
||||||
|
return self._num_tokens_to_verify
|
||||||
|
|
||||||
|
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
|
||||||
|
"""Set batch bucket to use speculatvie decoding.
|
||||||
|
This will notify the adjust the lengths of inputs during modeling,
|
||||||
|
and let the main model verifies tokens in parallel.
|
||||||
|
"""
|
||||||
|
self._use_spec_dec = True
|
||||||
|
self._num_tokens_to_verify = num_tokens_to_verify
|
||||||
|
|
||||||
|
def reset_use_spec_dec(self) -> None:
|
||||||
|
"""Reset the usage of speculative decoding for the batch bucket"""
|
||||||
|
self._use_spec_dec = False
|
||||||
|
self._num_tokens_to_verify = None
|
||||||
|
|
||||||
def _make_compact(self) -> None:
|
def _make_compact(self) -> None:
|
||||||
# Clean and Compress the batch based on its sequences dict.
|
# Clean and Compress the batch based on its sequences dict.
|
||||||
# Namely,compress sequences to the front and clean the seq lengths and block tables tensors.
|
# Namely,compress sequences to the front and clean the seq lengths and block tables tensors.
|
||||||
|
@ -347,6 +372,19 @@ class BatchBucket:
|
||||||
seq.check_finish()
|
seq.check_finish()
|
||||||
self._sequence_lengths[: self.current_batch_size] += 1
|
self._sequence_lengths[: self.current_batch_size] += 1
|
||||||
|
|
||||||
|
def revoke_batch_tokens(self, n: int) -> None:
|
||||||
|
"""Revoke the last n output tokens of the sequences in the batch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of output tokens to revoke from each sequence.
|
||||||
|
It does not count in the context tokens (input tokens).
|
||||||
|
"""
|
||||||
|
if n >= 1:
|
||||||
|
for seq_id, seq in self._sequences_dict.items():
|
||||||
|
assert seq.output_len >= n, "Revoking len exceeds the current output len of the sequence"
|
||||||
|
seq.output_token_id = seq.output_token_id[:-n]
|
||||||
|
self._sequence_lengths -= n
|
||||||
|
|
||||||
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
|
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
|
||||||
"""Clear all the sequences in the batch.
|
"""Clear all the sequences in the batch.
|
||||||
|
|
||||||
|
@ -401,6 +439,21 @@ class BatchBucket:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
|
||||||
|
# Used for main model verification in **Decoding Stage**
|
||||||
|
# `n` is the number of tokens to be verified,
|
||||||
|
# and so that prepare the last `n` tokens of each sequence as the inputs
|
||||||
|
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
||||||
|
assert all(
|
||||||
|
seq.output_len >= n for seq in self._sequences_dict.values()
|
||||||
|
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
|
||||||
|
out_li = []
|
||||||
|
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
seq: Sequence = self._sequences_dict[seq_id]
|
||||||
|
out_li.extend(seq.output_token_id[-n:])
|
||||||
|
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
# For compatibility
|
# For compatibility
|
||||||
def get_1D_inputs(self) -> torch.Tensor:
|
def get_1D_inputs(self) -> torch.Tensor:
|
||||||
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
assert len(self._sequences_dict) > 0, "No sequence in the batch"
|
||||||
|
@ -411,8 +464,6 @@ class BatchBucket:
|
||||||
seq.output_len == 0 for seq in self._sequences_dict.values()
|
seq.output_len == 0 for seq in self._sequences_dict.values()
|
||||||
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
||||||
out_li = []
|
out_li = []
|
||||||
num_tokens = torch.sum(self._sequence_lengths)
|
|
||||||
out = torch.empty([num_tokens], dtype=torch.long)
|
|
||||||
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
|
||||||
for seq_id in seq_ids:
|
for seq_id in seq_ids:
|
||||||
seq: Sequence = self._sequences_dict[seq_id]
|
seq: Sequence = self._sequences_dict[seq_id]
|
||||||
|
@ -420,6 +471,10 @@ class BatchBucket:
|
||||||
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
return torch.tensor(out_li, dtype=torch.long, device=self.device)
|
||||||
else:
|
else:
|
||||||
# Assume decoding stage
|
# Assume decoding stage
|
||||||
|
if self.use_spec_dec:
|
||||||
|
# For Speculative Decoding
|
||||||
|
# the number of tokens to be verified in parallel plus the correct token in the last step
|
||||||
|
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
|
||||||
assert all(
|
assert all(
|
||||||
seq.output_len > 0 for seq in self._sequences_dict.values()
|
seq.output_len > 0 for seq in self._sequences_dict.values()
|
||||||
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
|
||||||
|
|
|
@ -84,6 +84,8 @@ class InferenceConfig:
|
||||||
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
|
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
|
||||||
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
|
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
|
||||||
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
|
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
|
||||||
|
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
|
||||||
|
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
|
||||||
block_size (int): The number of blocks in a logical block, defaults to 16.
|
block_size (int): The number of blocks in a logical block, defaults to 16.
|
||||||
tp_size (int): Tensor parallel size, defaults to 1.
|
tp_size (int): Tensor parallel size, defaults to 1.
|
||||||
pp_size (int): Pipeline parallel size, defaults to 1.
|
pp_size (int): Pipeline parallel size, defaults to 1.
|
||||||
|
@ -118,6 +120,10 @@ class InferenceConfig:
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
min_p: Optional[float] = None
|
min_p: Optional[float] = None
|
||||||
|
|
||||||
|
# speculative decoding configs
|
||||||
|
max_n_spec_tokens: int = 5
|
||||||
|
glimpse_large_kv: bool = False
|
||||||
|
|
||||||
# paged attention configs
|
# paged attention configs
|
||||||
block_size: int = 16
|
block_size: int = 16
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
||||||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||||
from colossalai.inference.modeling.policy import model_policy_map
|
from colossalai.inference.modeling.policy import model_policy_map
|
||||||
|
from colossalai.inference.spec import Drafter
|
||||||
from colossalai.inference.struct import Sequence
|
from colossalai.inference.struct import Sequence
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
@ -52,19 +53,26 @@ class InferenceEngine:
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
model_policy: Policy = None,
|
model_policy: Policy = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert inference_config, "Please provide inference_config."
|
|
||||||
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
|
|
||||||
self.inference_config = inference_config
|
self.inference_config = inference_config
|
||||||
self.model_config = model.config
|
self.model_config = model.config
|
||||||
|
self.model = model
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
self.dtype = inference_config.dtype
|
self.dtype = inference_config.dtype
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
|
||||||
self.high_precision = inference_config.high_precision
|
self.high_precision = inference_config.high_precision
|
||||||
model = model.eval()
|
self._verify_args()
|
||||||
model = model.cuda()
|
|
||||||
model.to(self.dtype)
|
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||||
|
model.eval()
|
||||||
|
model = model.to(self.dtype)
|
||||||
|
model = model.to(self.device)
|
||||||
|
|
||||||
|
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||||
|
self.use_spec_dec = False
|
||||||
|
self.drafter_model = None
|
||||||
|
self.drafter = None
|
||||||
|
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||||
|
|
||||||
if model_policy is None:
|
if model_policy is None:
|
||||||
if self.inference_config.pad_input:
|
if self.inference_config.pad_input:
|
||||||
|
@ -174,21 +182,18 @@ class InferenceEngine:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||||
|
|
||||||
def _verify_config(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
"""
|
"""Verify the input args"""
|
||||||
Verify the input config
|
if not isinstance(self.inference_config, InferenceConfig):
|
||||||
"""
|
raise TypeError("Invalid type of inference config provided.")
|
||||||
if not isinstance(self.model, nn.Module):
|
if not isinstance(self.model, nn.Module):
|
||||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||||
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
|
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||||
self.tokenizer, PreTrainedTokenizer
|
|
||||||
):
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||||
)
|
)
|
||||||
assert (
|
if self.model.__class__.__name__ not in _supported_models:
|
||||||
self.model.__class__.__name__ in _supported_models
|
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
|
||||||
), f"Model {self.model.__class__.__name__} is not supported."
|
|
||||||
|
|
||||||
def _shardformer(
|
def _shardformer(
|
||||||
self,
|
self,
|
||||||
|
@ -224,6 +229,138 @@ class InferenceEngine:
|
||||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||||
return shard_model
|
return shard_model
|
||||||
|
|
||||||
|
def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
|
||||||
|
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||||
|
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||||
|
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||||
|
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||||
|
|
||||||
|
```python
|
||||||
|
...
|
||||||
|
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||||
|
|
||||||
|
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||||
|
engine.generate(...) # Speculative Decoding
|
||||||
|
|
||||||
|
engine.disable_spec_dec()
|
||||||
|
engine.generate(...) # Normal generation
|
||||||
|
|
||||||
|
engine.enable_spec_dec()
|
||||||
|
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||||
|
engine.clear_spec_dec()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if drafter_model is None and self.drafter is None:
|
||||||
|
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||||
|
if n_spec_tokens is not None:
|
||||||
|
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||||
|
self.n_spec_tokens = n_spec_tokens
|
||||||
|
if drafter_model is not None:
|
||||||
|
assert isinstance(drafter_model, nn.Module)
|
||||||
|
# overwrite the drafter, if exists
|
||||||
|
self.clear_spec_dec()
|
||||||
|
self.drafter_model = drafter_model
|
||||||
|
self.drafter = Drafter(
|
||||||
|
self.drafter_model,
|
||||||
|
self.tokenizer,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
# using speculative decoding for subsequent generations
|
||||||
|
self.use_spec_dec = True
|
||||||
|
|
||||||
|
def disable_spec_dec(self) -> None:
|
||||||
|
"""Disable using speculative decoding for subsequent generations."""
|
||||||
|
# set back to the maximum number of tokens to speculate
|
||||||
|
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||||
|
self.use_spec_dec = False
|
||||||
|
return
|
||||||
|
|
||||||
|
def clear_spec_dec(self) -> None:
|
||||||
|
"""Clear relatable structures of speculative decoding, if exist."""
|
||||||
|
if self.drafter_model or self.drafter:
|
||||||
|
self.drafter_model = None
|
||||||
|
self.drafter = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
self.use_spec_dec = False
|
||||||
|
return
|
||||||
|
|
||||||
|
def steps_spec_dec(self) -> List[Sequence]:
|
||||||
|
"""
|
||||||
|
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
|
||||||
|
with many steps of speculating by a drafter model as well as verifying by a main model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Sequence]: finished sequences generated by one step.
|
||||||
|
"""
|
||||||
|
batch = self.request_handler.schedule() # prefill batch
|
||||||
|
batch.set_use_spec_dec(self.n_spec_tokens) # set batch to use-spec-dec mode
|
||||||
|
|
||||||
|
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||||
|
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
|
||||||
|
|
||||||
|
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||||
|
drafter_out = self.drafter.speculate(input_ids, 1, None)
|
||||||
|
next_token_ids_spec = drafter_out.next_tokens
|
||||||
|
drafter_past_key_values = drafter_out.past_key_values
|
||||||
|
|
||||||
|
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||||
|
logits = self.model(batch, self.k_cahce, self.v_cache)
|
||||||
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
# append new inputs to the batch, temporarily
|
||||||
|
batch.append_batch_tokens(next_tokens)
|
||||||
|
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||||
|
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||||
|
input_ids = batch.get_1D_inputs_spec_dec(1)
|
||||||
|
|
||||||
|
batch.reset_use_spec_dec() # reset batch use-spec-dec mode
|
||||||
|
finished_sequences = self.request_handler.update()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# HACK Retrieve the running batch
|
||||||
|
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||||
|
batch = self.request_handler.running_bb # running batch
|
||||||
|
batch.set_use_spec_dec(self.n_spec_tokens)
|
||||||
|
|
||||||
|
# 3. Decoding - Drafter model speculates `n` tokens
|
||||||
|
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
|
||||||
|
next_token_ids_spec = drafter_out.next_tokens
|
||||||
|
drafter_past_key_values = drafter_out.past_key_values
|
||||||
|
|
||||||
|
for next_token_id_spec in next_token_ids_spec:
|
||||||
|
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||||
|
cur_length = batch.seq_lengths[0].item()
|
||||||
|
if already_allocated_kv_len < cur_length:
|
||||||
|
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||||
|
already_allocated_kv_len = cur_length
|
||||||
|
|
||||||
|
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||||
|
logits = self.model(batch, self.k_cahce, self.v_cache)
|
||||||
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
|
||||||
|
# 5. Compare and process the results
|
||||||
|
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||||
|
n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||||
|
# revoke appended tokens for each Sequence in the current batch
|
||||||
|
batch.revoke_batch_tokens(self.n_spec_tokens - n_matches) # revoke drafted tokens
|
||||||
|
# append the last correct token generated by the main model
|
||||||
|
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||||
|
input_ids = batch.get_1D_inputs_spec_dec(1)
|
||||||
|
# trim past key values of the drafter model
|
||||||
|
drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)
|
||||||
|
|
||||||
|
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||||
|
finished_sequences = self.request_handler.update()
|
||||||
|
if len(finished_sequences) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch.reset_use_spec_dec()
|
||||||
|
|
||||||
|
return finished_sequences
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str] = None,
|
prompts: List[str] = None,
|
||||||
|
@ -246,7 +383,6 @@ class InferenceEngine:
|
||||||
List[str]: Inference result returned by one generation.
|
List[str]: Inference result returned by one generation.
|
||||||
"""
|
"""
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.generation_config = generation_config
|
|
||||||
if prompts is not None or prompts_token_ids is not None:
|
if prompts is not None or prompts_token_ids is not None:
|
||||||
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||||
|
|
||||||
|
@ -257,6 +393,11 @@ class InferenceEngine:
|
||||||
if generation_config is not None:
|
if generation_config is not None:
|
||||||
self.generation_config = generation_config
|
self.generation_config = generation_config
|
||||||
|
|
||||||
|
if self.use_spec_dec:
|
||||||
|
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||||
|
while self.request_handler.check_unfinished_seqs():
|
||||||
|
output_seqs_list += self.steps_spec_dec()
|
||||||
|
else:
|
||||||
while self.request_handler.check_unfinished_seqs():
|
while self.request_handler.check_unfinished_seqs():
|
||||||
output_seqs_list += self.step()
|
output_seqs_list += self.step()
|
||||||
|
|
||||||
|
@ -428,7 +569,8 @@ class InferenceEngine:
|
||||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||||
if self.inference_config.pad_input:
|
if self.inference_config.pad_input:
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
self.request_handler.search_tokens(self.generation_config, logits)
|
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
self.request_handler.append_next_tokens(next_tokens)
|
||||||
|
|
||||||
finished_sequences = self.request_handler.update()
|
finished_sequences = self.request_handler.update()
|
||||||
|
|
||||||
|
|
|
@ -134,8 +134,12 @@ class RequestHandler:
|
||||||
if fd_inter_tensor._tensors_initialized:
|
if fd_inter_tensor._tensors_initialized:
|
||||||
fd_inter_tensor._reset()
|
fd_inter_tensor._reset()
|
||||||
|
|
||||||
|
# For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
|
||||||
|
max_n_tokens = self.max_batch_size
|
||||||
|
max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
|
||||||
|
|
||||||
fd_inter_tensor.initialize(
|
fd_inter_tensor.initialize(
|
||||||
max_batch_size=self.max_batch_size,
|
max_batch_size=max_n_tokens,
|
||||||
num_attn_heads=model_config.num_attention_heads,
|
num_attn_heads=model_config.num_attention_heads,
|
||||||
kv_max_split_num=kv_max_split_num,
|
kv_max_split_num=kv_max_split_num,
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
|
@ -230,6 +234,13 @@ class RequestHandler:
|
||||||
|
|
||||||
return self.running_bb
|
return self.running_bb
|
||||||
|
|
||||||
|
def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
|
||||||
|
assert batch.use_spec_dec
|
||||||
|
if n > 0:
|
||||||
|
self.cache_manager.allocate_n_tokens_from_block_tables(
|
||||||
|
batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
|
||||||
|
)
|
||||||
|
|
||||||
def add_sequence(self, req: Sequence):
|
def add_sequence(self, req: Sequence):
|
||||||
"""
|
"""
|
||||||
Add the request to waiting list.
|
Add the request to waiting list.
|
||||||
|
@ -282,13 +293,21 @@ class RequestHandler:
|
||||||
|
|
||||||
return sample_tokens
|
return sample_tokens
|
||||||
|
|
||||||
def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
|
def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
|
||||||
if (
|
if (
|
||||||
sequence.output_token_id[-1] == generation_config.eos_id
|
sequence.output_token_id[-1] == generation_config.eos_token_id
|
||||||
or sequence.output_len >= generation_config.max_output_len
|
or sequence.output_len >= generation_config.max_length
|
||||||
):
|
):
|
||||||
sequence.mark_finished()
|
sequence.mark_finished()
|
||||||
|
|
||||||
|
def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
|
||||||
|
for seq in batch.seqs_li:
|
||||||
|
if (
|
||||||
|
seq.output_token_id[-1] == generation_config.eos_token_id
|
||||||
|
or seq.output_len >= generation_config.max_length
|
||||||
|
):
|
||||||
|
seq.mark_finished()
|
||||||
|
|
||||||
def check_unfinished_seqs(self) -> bool:
|
def check_unfinished_seqs(self) -> bool:
|
||||||
return self._has_waiting() or not self.running_list.is_empty()
|
return self._has_waiting() or not self.running_list.is_empty()
|
||||||
|
|
||||||
|
@ -309,9 +328,20 @@ class RequestHandler:
|
||||||
|
|
||||||
# sample the next tokens
|
# sample the next tokens
|
||||||
sample_tokens = self._sample(probs, logprobs, generation_config)
|
sample_tokens = self._sample(probs, logprobs, generation_config)
|
||||||
|
return sample_tokens
|
||||||
|
|
||||||
|
def append_next_tokens(self, sample_tokens: torch.Tensor):
|
||||||
|
assert sample_tokens.dim() == 1
|
||||||
|
n_elements = sample_tokens.size(0)
|
||||||
if not self.prefill_bb.is_empty:
|
if not self.prefill_bb.is_empty:
|
||||||
|
assert (
|
||||||
|
self.prefill_bb.current_batch_size == n_elements
|
||||||
|
), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
|
||||||
self.prefill_bb.append_batch_tokens(sample_tokens)
|
self.prefill_bb.append_batch_tokens(sample_tokens)
|
||||||
else:
|
else:
|
||||||
|
assert (
|
||||||
|
self.running_bb.current_batch_size == n_elements
|
||||||
|
), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
|
||||||
self.running_bb.append_batch_tokens(sample_tokens)
|
self.running_bb.append_batch_tokens(sample_tokens)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
|
|
@ -349,6 +349,26 @@ class KVCacheManager:
|
||||||
|
|
||||||
return seqs_to_recycle
|
return seqs_to_recycle
|
||||||
|
|
||||||
|
def allocate_n_tokens_from_block_tables(
|
||||||
|
self,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
context_lens: torch.Tensor,
|
||||||
|
bsz: int,
|
||||||
|
n: int,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage."""
|
||||||
|
assert block_tables.dim() == 2
|
||||||
|
assert context_lens.dim() == 1
|
||||||
|
|
||||||
|
bsz = block_tables.size(0) if bsz is None else bsz
|
||||||
|
assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1
|
||||||
|
|
||||||
|
seqs_to_recycle = []
|
||||||
|
for i in range(n):
|
||||||
|
seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)
|
||||||
|
|
||||||
|
return seqs_to_recycle
|
||||||
|
|
||||||
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
|
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
|
||||||
"""Allocate space asked on a single block in the block table, specified by the provided position id,
|
"""Allocate space asked on a single block in the block table, specified by the provided position id,
|
||||||
and updates the provided block table with the allocated block.
|
and updates the provided block table with the allocated block.
|
||||||
|
@ -420,9 +440,7 @@ class KVCacheManager:
|
||||||
Returns:
|
Returns:
|
||||||
The remaining space required to be allocated (in other blocks).
|
The remaining space required to be allocated (in other blocks).
|
||||||
"""
|
"""
|
||||||
assert (
|
assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
|
||||||
block.available_space > 0
|
|
||||||
), "Tried to allocate some space but found no available space left in chosen block."
|
|
||||||
space_to_allocate = min(block.available_space, space_asked)
|
space_to_allocate = min(block.available_space, space_asked)
|
||||||
block.allocate(space_to_allocate)
|
block.allocate(space_to_allocate)
|
||||||
return space_asked - space_to_allocate
|
return space_asked - space_to_allocate
|
||||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import (
|
from colossalai.kernel.triton import (
|
||||||
context_attention_unpadded,
|
context_attention_unpadded,
|
||||||
|
copy_k_to_blocked_cache,
|
||||||
decoding_fused_rotary_embedding,
|
decoding_fused_rotary_embedding,
|
||||||
flash_decoding_attention,
|
flash_decoding_attention,
|
||||||
get_xine_cache,
|
get_xine_cache,
|
||||||
|
@ -84,9 +85,9 @@ def llama_model_forward(
|
||||||
"""This function will replace the forward function of LlamaModel.
|
"""This function will replace the forward function of LlamaModel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch (BatchInfo): It stores the necessary input information for this inference.
|
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
|
||||||
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
|
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
|
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||||
"""
|
"""
|
||||||
block_tables = inputmetadata.block_tables
|
block_tables = inputmetadata.block_tables
|
||||||
|
@ -101,7 +102,25 @@ def llama_model_forward(
|
||||||
use_cuda_kernel = False
|
use_cuda_kernel = False
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_tokens_ids)
|
hidden_states = self.embed_tokens(input_tokens_ids)
|
||||||
if use_cuda_kernel:
|
cu_seqlens = None
|
||||||
|
|
||||||
|
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
|
||||||
|
if inputmetadata.use_spec_dec:
|
||||||
|
# For speculative-decoding Prefill and Verifying Stage
|
||||||
|
if inputmetadata.is_prompts:
|
||||||
|
# output tensor shape is the same as normal Prefill Stage
|
||||||
|
o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
|
||||||
|
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
|
||||||
|
else:
|
||||||
|
# the number of tokens to be verified in parallel plus the correct token in the last step
|
||||||
|
n_tokens = inputmetadata.num_tokens_to_verify + 1
|
||||||
|
assert n_tokens == hidden_states.size(0)
|
||||||
|
o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
|
||||||
|
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
|
||||||
|
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
|
||||||
|
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
|
||||||
|
|
||||||
|
elif use_cuda_kernel:
|
||||||
if inputmetadata != torch.float32 and use_flash_attn2:
|
if inputmetadata != torch.float32 and use_flash_attn2:
|
||||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||||
|
|
||||||
|
@ -113,14 +132,22 @@ def llama_model_forward(
|
||||||
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
|
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
|
||||||
)
|
)
|
||||||
cos_sin = (cos, sin)
|
cos_sin = (cos, sin)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cu_seqlens = None
|
|
||||||
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
|
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
|
||||||
|
|
||||||
|
# TODO (yuanheng-zhao): revise the logic here
|
||||||
|
# if batch.is_prompts:
|
||||||
|
# output_tensor = torch.zeros(
|
||||||
|
# (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# output_tensor = torch.zeros(
|
||||||
|
# (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||||
|
# )
|
||||||
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
|
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
|
||||||
|
|
||||||
norm_output = torch.empty_like(hidden_states)
|
norm_output = torch.empty_like(hidden_states)
|
||||||
|
tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
|
||||||
residual = None
|
residual = None
|
||||||
|
|
||||||
for layer_id, decoder_layer in enumerate(self.layers):
|
for layer_id, decoder_layer in enumerate(self.layers):
|
||||||
|
@ -131,6 +158,8 @@ def llama_model_forward(
|
||||||
k_cache=k_caches[layer_id],
|
k_cache=k_caches[layer_id],
|
||||||
v_cache=v_caches[layer_id],
|
v_cache=v_caches[layer_id],
|
||||||
is_prompts=inputmetadata.is_prompts,
|
is_prompts=inputmetadata.is_prompts,
|
||||||
|
is_verifier=inputmetadata.use_spec_dec,
|
||||||
|
tokens_to_verify=tokens_to_verify,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
cos_sin=cos_sin,
|
cos_sin=cos_sin,
|
||||||
fd_inter_tensor=inputmetadata.fd_inter_tensor,
|
fd_inter_tensor=inputmetadata.fd_inter_tensor,
|
||||||
|
@ -144,9 +173,9 @@ def llama_model_forward(
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputmetadata.is_prompts:
|
if inputmetadata.is_prompts:
|
||||||
last_token_indexs = sequence_lengths.cumsum(dim=-1)
|
seq_len_cumsum = sequence_lengths.cumsum(dim=0)
|
||||||
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
|
||||||
residual = residual[last_token_indexs - 1].contiguous()
|
residual = residual[seq_len_cumsum - 1].contiguous()
|
||||||
norm_output = torch.empty_like(hidden_states)
|
norm_output = torch.empty_like(hidden_states)
|
||||||
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||||
|
|
||||||
|
@ -164,6 +193,8 @@ def llama_decoder_layer_forward(
|
||||||
cos_sin: Tuple[torch.Tensor],
|
cos_sin: Tuple[torch.Tensor],
|
||||||
fd_inter_tensor: FDIntermTensors,
|
fd_inter_tensor: FDIntermTensors,
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
|
is_verifier: bool = False,
|
||||||
|
tokens_to_verify: int = None,
|
||||||
kv_seq_len: int = 0,
|
kv_seq_len: int = 0,
|
||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
norm_output: torch.Tensor = None,
|
norm_output: torch.Tensor = None,
|
||||||
|
@ -202,6 +233,9 @@ def llama_decoder_layer_forward(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
k_cache=k_cache,
|
k_cache=k_cache,
|
||||||
v_cache=v_cache,
|
v_cache=v_cache,
|
||||||
|
is_prompts=is_prompts,
|
||||||
|
is_verifier=is_verifier,
|
||||||
|
tokens_to_verify=tokens_to_verify,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
cos_sin=cos_sin,
|
cos_sin=cos_sin,
|
||||||
fd_inter_tensor=fd_inter_tensor,
|
fd_inter_tensor=fd_inter_tensor,
|
||||||
|
@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
cos_sin: Tuple[torch.Tensor],
|
cos_sin: Tuple[torch.Tensor],
|
||||||
fd_inter_tensor: FDIntermTensors,
|
fd_inter_tensor: FDIntermTensors,
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
|
is_verifier: bool = False,
|
||||||
|
tokens_to_verify: int = None,
|
||||||
kv_seq_len: int = 0,
|
kv_seq_len: int = 0,
|
||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
|
@ -355,7 +391,7 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
|
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
|
||||||
# flash attn 2 currently only supports FP16/BF16.
|
# flash attn 2 currently only supports FP16/BF16.
|
||||||
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
|
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
|
||||||
inference_ops.context_kv_cache_memcpy(
|
inference_ops.context_kv_cache_memcpy(
|
||||||
|
@ -404,6 +440,16 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
block_tables,
|
block_tables,
|
||||||
high_precision,
|
high_precision,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||||
|
if is_verifier:
|
||||||
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
copy_k_to_blocked_cache(
|
||||||
|
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||||
|
)
|
||||||
|
copy_k_to_blocked_cache(
|
||||||
|
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
decoding_fused_rotary_embedding(
|
decoding_fused_rotary_embedding(
|
||||||
query_states,
|
query_states,
|
||||||
|
@ -428,8 +474,10 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
mid_output=fd_inter_tensor.mid_output,
|
mid_output=fd_inter_tensor.mid_output,
|
||||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
|
q_len=q_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(-1, self.hidden_size)
|
||||||
attn_output = torch.mm(attn_output, self.o_proj_weight)
|
attn_output = torch.mm(attn_output, self.o_proj_weight)
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
|
@ -15,93 +15,75 @@ class Drafter:
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The drafter model.
|
model (nn.Module): The drafter model.
|
||||||
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
|
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
|
||||||
max_spec_num (int): The maximum number of tokens to speculate.
|
|
||||||
device (torch.device): The device for the drafter model.
|
device (torch.device): The device for the drafter model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
device: torch.device = None,
|
||||||
|
dtype: torch.dtype = torch.float16,
|
||||||
):
|
):
|
||||||
self._drafter_model = model
|
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self.max_spec_num = max_spec_num
|
|
||||||
self.do_sample = False
|
|
||||||
self.sample_fn = None
|
|
||||||
self._device = device or get_current_device()
|
self._device = device or get_current_device()
|
||||||
self._past_key_values = None
|
self._dtype = dtype
|
||||||
|
self._drafter_model = model.to(self._device)
|
||||||
@property
|
self._drafter_model = model.to(self._dtype)
|
||||||
def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
|
self._drafter_model.eval()
|
||||||
return self._past_key_values
|
|
||||||
|
|
||||||
# Debug usage for now
|
|
||||||
@property
|
|
||||||
def past_key_values_shape(self):
|
|
||||||
if self._past_key_values is None:
|
|
||||||
return []
|
|
||||||
return self._past_key_values[0][0].shape
|
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self._drafter_model
|
return self._drafter_model
|
||||||
|
|
||||||
def reset_sample_method(self, sample_fn: callable) -> None:
|
@staticmethod
|
||||||
self.do_sample = True
|
def trim_kv_cache(
|
||||||
self.sample_fn = sample_fn
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
|
||||||
|
) -> Tuple[Tuple[torch.FloatTensor]]:
|
||||||
|
"""Trim the last `invalid_token_num` kv caches.
|
||||||
|
|
||||||
def clear_sample_method(self) -> None:
|
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
|
||||||
self.do_sample = False
|
num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
|
||||||
self.sample_fn = None
|
invalid_token_num (int): The number of invalid tokens to trim.
|
||||||
|
"""
|
||||||
|
if past_key_values is None or invalid_token_num < 1:
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
def reset_max_spec_num(self, n: int) -> None:
|
|
||||||
assert isinstance(n, int) and n > 1
|
|
||||||
self.max_spec_num = n
|
|
||||||
|
|
||||||
def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None:
|
|
||||||
self._past_key_values = past_key_values
|
|
||||||
|
|
||||||
def trim_kv_cache(self, invalid_token_num) -> Tuple[Tuple[torch.FloatTensor]]:
|
|
||||||
# Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
|
|
||||||
# Trim the last `invalid_token_num` kv caches
|
|
||||||
# The verifier (main model) might reject `invalid_token_num` tokens,
|
|
||||||
# and so that we have to trim the invalid tokens for the kv cache of the drafter model.
|
|
||||||
assert self._past_key_values is not None
|
|
||||||
trimmed_past_key_values = []
|
trimmed_past_key_values = []
|
||||||
for layer_idx in range(len(self._past_key_values)):
|
for layer_idx in range(len(past_key_values)):
|
||||||
past_key_value = self._past_key_values[layer_idx]
|
past_key_value = past_key_values[layer_idx]
|
||||||
trimmed_past_key_values.append(
|
trimmed_past_key_values.append(
|
||||||
(
|
(
|
||||||
past_key_value[0][:, :, :-invalid_token_num, :],
|
past_key_value[0][:, :, :-invalid_token_num, :],
|
||||||
past_key_value[1][:, :, :-invalid_token_num, :],
|
past_key_value[1][:, :, :-invalid_token_num, :],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._past_key_values = tuple(trimmed_past_key_values)
|
past_key_values = tuple(trimmed_past_key_values)
|
||||||
return self._past_key_values
|
return past_key_values
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def speculate(
|
def speculate(
|
||||||
self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
n_spec_tokens: int,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
) -> DrafterOutput:
|
) -> DrafterOutput:
|
||||||
"""Generate n tokens using the drafter model.
|
"""Generate n_spec_tokens tokens using the drafter model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids (torch.Tensor): Input token ids.
|
input_ids (torch.Tensor): Input token ids.
|
||||||
n (int): Number of tokens to speculate.
|
n_spec_tokens (int): Number of tokens to speculate.
|
||||||
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
|
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
|
||||||
"""
|
"""
|
||||||
|
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
|
||||||
|
|
||||||
assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate"
|
# For compatibility with transformers of versions before 4.38.0
|
||||||
|
|
||||||
# FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0)
|
|
||||||
if input_ids.dim() == 1:
|
if input_ids.dim() == 1:
|
||||||
input_ids = input_ids.unsqueeze(0)
|
input_ids = input_ids.unsqueeze(0)
|
||||||
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = self._past_key_values
|
|
||||||
|
|
||||||
logits = []
|
logits = []
|
||||||
token_ids = []
|
token_ids = []
|
||||||
|
|
||||||
for _ in range(n):
|
for _ in range(n_spec_tokens):
|
||||||
outputs = self._drafter_model(
|
outputs = self._drafter_model(
|
||||||
input_ids,
|
input_ids,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
@ -110,16 +92,9 @@ class Drafter:
|
||||||
)
|
)
|
||||||
next_token_logits = outputs.logits[:, -1, :]
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
# Skip logits_processor for drafter model
|
# NOTE Only use greedy search for speculating.
|
||||||
|
# As the drafter model usually has only a few layers with few parameters,
|
||||||
# Sample
|
# introducing sampling will make the speculation unstable and lead to worse performance.
|
||||||
if self.do_sample:
|
|
||||||
if self.sample_fn is not None:
|
|
||||||
probs = self.sample_fn(next_token_logits)
|
|
||||||
else:
|
|
||||||
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
|
||||||
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
||||||
else:
|
|
||||||
next_token_ids = torch.argmax(next_token_logits, dim=-1)
|
next_token_ids = torch.argmax(next_token_logits, dim=-1)
|
||||||
|
|
||||||
logits.append(next_token_logits)
|
logits.append(next_token_logits)
|
||||||
|
@ -133,8 +108,6 @@ class Drafter:
|
||||||
speculated_length = len(token_ids) # TODO For now, only support bsz 1
|
speculated_length = len(token_ids) # TODO For now, only support bsz 1
|
||||||
logits = torch.concat(logits, dim=0)
|
logits = torch.concat(logits, dim=0)
|
||||||
token_ids = torch.concat(token_ids, dim=-1)
|
token_ids = torch.concat(token_ids, dim=-1)
|
||||||
# update past_key_values
|
|
||||||
self._past_key_values = past_key_values
|
|
||||||
|
|
||||||
out = DrafterOutput(
|
out = DrafterOutput(
|
||||||
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
|
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
|
||||||
|
|
|
@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel(
|
||||||
cur_seq_idx = cur_token_idx // q_len
|
cur_seq_idx = cur_token_idx // q_len
|
||||||
if cur_seq_idx >= batch_size:
|
if cur_seq_idx >= batch_size:
|
||||||
return
|
return
|
||||||
|
cur_token_off = (cur_token_idx % q_len) - q_len + 1
|
||||||
cur_head_idx = tl.program_id(1)
|
cur_head_idx = tl.program_id(1)
|
||||||
block_start_kv = tl.program_id(2) # for splitting k/v
|
block_start_kv = tl.program_id(2) # for splitting k/v
|
||||||
|
|
||||||
|
@ -52,7 +53,8 @@ def _flash_decoding_fwd_kernel(
|
||||||
# and then support calculating multiple kv cache blocks on an instance
|
# and then support calculating multiple kv cache blocks on an instance
|
||||||
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
||||||
# get the current (kv) sequence length
|
# get the current (kv) sequence length
|
||||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
# cur_token_off is used as a "mask" here for spec-dec during verification process
|
||||||
|
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
|
||||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
return
|
return
|
||||||
cur_head_idx = tl.program_id(1)
|
cur_head_idx = tl.program_id(1)
|
||||||
|
|
||||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
# cur_token_off is used as a "mask" here for spec-dec during verification process
|
||||||
|
cur_token_off = (cur_token_idx % q_len) - q_len + 1
|
||||||
|
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
|
|
||||||
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
|
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
|
||||||
|
|
|
@ -2,10 +2,15 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
|
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.inference.config import GenerationConfig, InferenceConfig
|
||||||
|
from colossalai.inference.core.engine import InferenceEngine
|
||||||
from colossalai.inference.spec.drafter import Drafter
|
from colossalai.inference.spec.drafter import Drafter
|
||||||
|
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
NUM_LAYERS = 2
|
NUM_LAYERS = 2
|
||||||
|
MAX_LEN = 100
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("spec_num", [5])
|
@pytest.mark.parametrize("spec_num", [5])
|
||||||
|
@ -14,13 +19,13 @@ def test_drafter(spec_num: int):
|
||||||
|
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
||||||
toy_config.pad_token_id = toy_config.eos_token_id
|
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||||
drafter_model = LlamaForCausalLM(toy_config)
|
drafter_model = LlamaForCausalLM(toy_config)
|
||||||
drafter_model = drafter_model.eval().cuda()
|
drafter_model = drafter_model.eval().cuda()
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
|
||||||
|
|
||||||
drafter = Drafter(drafter_model, tokenizer, spec_num, device=device)
|
drafter = Drafter(drafter_model, tokenizer, device=device)
|
||||||
|
|
||||||
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
|
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
|
||||||
out = drafter.speculate(input_ids, spec_num)
|
out = drafter.speculate(input_ids, spec_num)
|
||||||
|
@ -29,13 +34,75 @@ def test_drafter(spec_num: int):
|
||||||
assert out.speculated_length == spec_num
|
assert out.speculated_length == spec_num
|
||||||
assert out.next_tokens.shape == (spec_num,)
|
assert out.next_tokens.shape == (spec_num,)
|
||||||
assert out.logits.shape == (spec_num, len(tokenizer))
|
assert out.logits.shape == (spec_num, len(tokenizer))
|
||||||
assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length
|
assert out.past_key_values[0][0].size(2) == past_kv_length
|
||||||
|
|
||||||
reject_num = 3
|
reject_num = max(0, spec_num - 1)
|
||||||
assert reject_num <= spec_num
|
trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
|
||||||
drafter.trim_kv_cache(reject_num)
|
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
|
||||||
assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num
|
|
||||||
|
|
||||||
|
def check_sd():
|
||||||
|
torch.manual_seed(123)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
|
# Dummy configs for testing
|
||||||
|
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
||||||
|
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
drafter_model = LlamaForCausalLM(toy_config)
|
||||||
|
drafter_model = drafter_model.eval().cuda()
|
||||||
|
large_config = LlamaConfig(
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=8,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
)
|
||||||
|
large_config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
main_model = LlamaForCausalLM(large_config)
|
||||||
|
|
||||||
|
inference_config = InferenceConfig(
|
||||||
|
dtype="fp16",
|
||||||
|
micro_batch_size=1,
|
||||||
|
max_batch_size=1,
|
||||||
|
max_input_len=128,
|
||||||
|
max_output_len=128,
|
||||||
|
prefill_ratio=1.2,
|
||||||
|
block_size=16,
|
||||||
|
)
|
||||||
|
engine = InferenceEngine(main_model, tokenizer, inference_config)
|
||||||
|
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||||
|
|
||||||
|
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
max_length=MAX_LEN,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
out, out_token_ids = engine.generate(
|
||||||
|
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
|
||||||
|
)
|
||||||
|
engine.disable_spec_dec()
|
||||||
|
engine.clear_spec_dec()
|
||||||
|
|
||||||
|
assert not engine.use_spec_dec
|
||||||
|
assert engine.drafter is None and engine.drafter_model is None
|
||||||
|
|
||||||
|
assert len(out) == 1
|
||||||
|
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
|
check_sd()
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def test_spec_dec():
|
||||||
|
spawn(run_dist, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_drafter(spec_num=5)
|
test_drafter(spec_num=5)
|
||||||
|
test_spec_dec()
|
||||||
|
|
|
@ -19,12 +19,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
||||||
|
|
||||||
|
|
||||||
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
|
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
|
||||||
|
assert q_len <= kv_len
|
||||||
|
|
||||||
|
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
|
||||||
|
|
||||||
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
|
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cur_seq_len = kv_lengths[i].item()
|
cur_seq_len = kv_lengths[i].item()
|
||||||
assert cur_seq_len <= kv_len
|
assert cur_seq_len <= kv_len
|
||||||
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
|
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
|
||||||
|
|
||||||
|
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
|
||||||
|
|
||||||
return padding_mask
|
return padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,11 +63,13 @@ def torch_attn_ref(
|
||||||
attn_scores = qk / (head_dim**0.5)
|
attn_scores = qk / (head_dim**0.5)
|
||||||
|
|
||||||
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
|
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
|
||||||
# for left-side padding
|
if attention_mask is not None:
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_len):
|
if attention_mask.size() != (bsz, 1, q_len, kv_len):
|
||||||
raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}")
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
attn_scores = attn_scores + attention_mask
|
attn_scores = attn_scores + attention_mask
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
|
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
|
||||||
out = torch.matmul(attn_weights, v)
|
out = torch.matmul(attn_weights, v)
|
||||||
if out.size() != (bsz, num_heads, q_len, head_dim):
|
if out.size() != (bsz, num_heads, q_len, head_dim):
|
||||||
|
|
|
@ -6,8 +6,8 @@ from colossalai.kernel.triton import flash_decoding_attention
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import (
|
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||||
convert_kv_unpad_to_padded,
|
convert_kv_unpad_to_padded,
|
||||||
|
create_attention_mask,
|
||||||
generate_caches_and_block_tables_v2,
|
generate_caches_and_block_tables_v2,
|
||||||
prepare_padding_mask,
|
|
||||||
torch_attn_ref,
|
torch_attn_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,9 +91,9 @@ def test_flash_decoding(
|
||||||
|
|
||||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
||||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
||||||
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
|
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
|
||||||
out_torch = torch_attn_ref(
|
out_torch = torch_attn_ref(
|
||||||
q, k_torch, v_torch, torch_padding_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||||
)
|
)
|
||||||
|
|
||||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
|
@ -138,6 +138,5 @@ def test_flash_decoding(
|
||||||
assert out_torch.shape == out_triton.shape
|
assert out_torch.shape == out_triton.shape
|
||||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_flash_decoding(16, 32, 32, 16, 1, True)
|
test_flash_decoding(16, 32, 32, 16, 1, True)
|
||||||
|
|
Loading…
Reference in New Issue