[Inference]Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to the opinions of review

* add Citation

* change _block_tables tolist
pull/5782/head
yuehuayingxueluo 2024-06-05 10:51:19 +08:00 committed by GitHub
parent ee6fd38373
commit b45000f839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 276 additions and 12 deletions

View File

@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr
- [vLLM](https://github.com/vllm-project/vllm) - [vLLM](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention) - [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [HuggingFace](https://huggingface.co) - [HuggingFace](https://huggingface.co)
- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm)
If you wish to cite relevant research papers, you can find the reference below. If you wish to cite relevant research papers, you can find the reference below.
```bibtex ```bibtex
@ -301,4 +302,12 @@ If you wish to cite relevant research papars, you can find the reference below.
author={Dao, Tri}, author={Dao, Tri},
year={2023} year={2023}
} }
# StreamingLLM
@article{xiao2023streamingllm,
title={Efficient Streaming Language Models with Attention Sinks},
author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
journal={arXiv},
year={2023}
}
``` ```

View File

@ -31,6 +31,9 @@ class BatchBucket:
fd_interm_tensor=None, fd_interm_tensor=None,
device=None, device=None,
dtype=torch.float16, dtype=torch.float16,
enable_streamingllm: bool = False,
start_token_size: int = 4,
generated_token_size: int = 512,
): ):
self.num_heads = num_heads self.num_heads = num_heads
self.head_dim = head_dim self.head_dim = head_dim
@ -45,12 +48,19 @@ class BatchBucket:
self._use_spec_dec = False self._use_spec_dec = False
self._num_tokens_to_verify = None self._num_tokens_to_verify = None
self.enable_streamingllm = enable_streamingllm
self.start_token_size = start_token_size
self.generated_token_size = generated_token_size
self._current_batch_size = 0 self._current_batch_size = 0
self._sequences_dict = dict() self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size if enable_streamingllm:
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
else:
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
self._block_tables_helper = torch.full_like(self._block_tables, -1) self._block_tables_helper = torch.full_like(self._block_tables, -1)
@ -109,6 +119,33 @@ class BatchBucket:
out.append(seq.input_token_id + seq.output_token_id) out.append(seq.input_token_id + seq.output_token_id)
return out return out
def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
    """
    Slide the StreamingLLM window for every running sequence whose cache is full.

    A sequence is treated as full when its recorded length equals
    ``start_token_size + generated_token_size + self.block_size - 1``. For each such
    sequence the entry at index 1 of its block table is evicted, the following entries
    shift left by one, the last entry becomes -1, and the recorded length is reset to
    ``start_token_size + generated_token_size - 1``.

    Both tensors are edited in place, so ``_block_tables`` keeps every row (not only the
    first ``current_batch_size`` rows) and neither tensor changes dtype or device.

    Args:
        start_token_size (int): Number of leading tokens that are always kept.
        generated_token_size (int): Number of most recent tokens kept in the window.

    Returns:
        List[int]: Evicted block ids in batch order; empty when nothing was evicted.
    """
    updated_block_ids = []
    batch_size = self.current_batch_size
    if batch_size <= 0:
        return updated_block_ids

    window_size = start_token_size + generated_token_size
    # We assume that the start token occupies the entire first block.
    full_length = window_size + self.block_size - 1

    # Read all active lengths with a single tolist() rather than one read per element.
    active_lengths = self._sequence_lengths[:batch_size].tolist()
    for batch_id, seq_len in enumerate(active_lengths):
        if seq_len != full_length:
            continue
        block_table = self._block_tables[batch_id]
        updated_block_ids.append(int(block_table[1]))
        # Source and destination slices overlap, so copy the source before shifting.
        block_table[1:-1] = block_table[2:].clone()
        block_table[-1] = -1
        self._sequence_lengths[batch_id] = window_size - 1

    return updated_block_ids
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding. """Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling, This will notify the adjust the lengths of inputs during modeling,

View File

@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None. top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None. top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0. temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16. block_size (int): The number of blocks in a logical block, defaults to 16.
@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1. micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
""" """
# NOTE: arrange configs according to their importance and frequency of usage # NOTE: arrange configs according to their importance and frequency of usage
@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size: Optional[int] = 0 no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0
forced_eos_token_id: int = None forced_eos_token_id: int = None
ignore_eos: bool = False
# speculative decoding configs # speculative decoding configs
max_n_spec_tokens: int = 5 max_n_spec_tokens: int = 5
@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
pp_size: int = 1 pp_size: int = 1
micro_batch_size: int = 1 micro_batch_size: int = 1
micro_batch_buffer_size: int = None micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option # cuda kernel option
use_cuda_kernel: bool = False use_cuda_kernel: bool = False
high_precision: Optional[bool] = False
# cuda_graph # cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512 max_context_len_to_capture: int = 512
ignore_eos: bool = False
# StreamingLLM (sliding window attention with attention sinks)
enable_streamingllm: bool = False
start_token_size: int = 4
generated_token_size: int = 512
def __post_init__(self): def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len self.max_context_len_to_capture = self.max_input_len + self.max_output_len
@ -260,6 +268,20 @@ class InferenceConfig(RPC_PARAM):
if self.dtype == torch.float32: if self.dtype == torch.float32:
self.high_precision = False self.high_precision = False
# check StreamingLLM
assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
# based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
# we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
# we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
# check prompt template # check prompt template
if self.prompt_template is None: if self.prompt_template is None:
return return

View File

@ -667,6 +667,11 @@ class InferenceEngine:
elif max_length is not None: elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i]) max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence( sequence = Sequence(
request_id, request_id,
prompt, prompt,
@ -754,6 +759,13 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input: if self.inference_config.pad_input:
logits = logits[:, -1, :] logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens( next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
) )

View File

@ -157,6 +157,9 @@ class RequestHandler:
fd_interm_tensor=fd_inter_tensor, fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
) )
self.prefill_bb = BatchBucket( self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size, num_heads=model_config.num_attention_heads // inference_config.tp_size,
@ -168,6 +171,9 @@ class RequestHandler:
fd_interm_tensor=fd_inter_tensor, fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
) )
def _init_cache(self, model_config): def _init_cache(self, model_config):
@ -350,6 +356,12 @@ class RequestHandler:
return finished_seqs return finished_seqs
def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
    """
    Hand the swapped-out block ids to the cache manager so it can free them.

    Args:
        updated_block_ids (List[int]): Ids of the blocks that were swapped out.
    """
    manager = self.cache_manager
    manager.streamingllm_free_block_tables(updated_block_ids)
class RPCRequestHandler(RequestHandler): class RPCRequestHandler(RequestHandler):
""" """

View File

@ -78,10 +78,16 @@ class KVCacheManager:
self.max_output_length = config.max_output_len self.max_output_length = config.max_output_len
# Cache block settings # Cache block settings
self.block_size = config.block_size self.block_size = config.block_size
# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
self.max_blocks_per_sequence = ( if config.enable_streamingllm:
self.max_input_length + self.max_output_length + self.block_size - 1 self.max_blocks_per_sequence = (
) // self.block_size config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
# Physical cache allocation # Physical cache allocation
@ -446,6 +452,20 @@ class KVCacheManager:
self._available_blocks = self.num_blocks self._available_blocks = self.num_blocks
self._block_states[:] = 1 self._block_states[:] = 1
def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
    """
    Free the blocks that need to be swapped out.

    Each non-negative id drops one reference from its cache block. When the block has
    no references left, its allocated size is reset to 0, the available-block counter
    is incremented and the block's state entry is set to 1.

    Negative ids are skipped individually, so one placeholder id does not stop the
    ids after it from being freed.

    Args:
        updated_block_ids (List[int]): Global ids of the blocks to release.
    """
    for global_block_id in updated_block_ids:
        if global_block_id < 0:
            # Not a real block (a negative index would also wrap around); skip only
            # this entry and keep processing the rest.
            continue
        block = self._cache_blocks[global_block_id]
        block.remove_ref()
        if not block.has_ref():
            block.allocated_size = 0
            self._available_blocks += 1
            self._block_states[global_block_id] = 1
def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
@ -533,10 +553,16 @@ class RPCKVCacheManager(KVCacheManager):
self.max_output_length = config.max_output_len self.max_output_length = config.max_output_len
# Cache block settings # Cache block settings
self.block_size = config.block_size self.block_size = config.block_size
# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
self.max_blocks_per_sequence = ( if config.enable_streamingllm:
self.max_input_length + self.max_output_length + self.block_size - 1 self.max_blocks_per_sequence = (
) // self.block_size config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
# Logical cache blocks allocation # Logical cache blocks allocation

View File

@ -48,6 +48,9 @@ def infer(args):
block_size=16, block_size=16,
tp_size=args.tp_size, tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel, use_cuda_kernel=args.use_cuda_kernel,
enable_streamingllm=args.enable_streamingllm,
start_token_size=args.start_token_size,
generated_token_size=args.generated_token_size,
) )
coordinator.print_on_master(f"Initializing Inference Engine...") coordinator.print_on_master(f"Initializing Inference Engine...")
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True) engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
@ -63,6 +66,8 @@ def infer(args):
temperature=args.temperature, temperature=args.temperature,
top_k=args.top_k, top_k=args.top_k,
top_p=args.top_p, top_p=args.top_p,
no_repeat_ngram_size=args.no_repeat_ngram_size,
repetition_penalty=args.repetition_penalty,
) )
coordinator.print_on_master(f"Generating...") coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=generation_config) out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
@ -107,6 +112,25 @@ if __name__ == "__main__":
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation") parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
parser.add_argument("--top_k", type=int, default=50, help="Top k for generation") parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation") parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
parser.add_argument(
"--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM,"
)
parser.add_argument(
"--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM"
)
parser.add_argument(
"--no_repeat_ngram_size",
type=int,
default=0,
help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.0,
help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.",
)
args = parser.parse_args() args = parser.parse_args()
infer(args) infer(args)

View File

@ -0,0 +1,122 @@
import random
import numpy as np
import torch
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def data_gen(batch_size: int = 4, seq_len: int = 512):
    """Return a random ``(batch_size, seq_len)`` tensor of token ids in [10, 30000) on the current CUDA device."""
    shape = (batch_size, seq_len)
    return torch.randint(10, 30000, shape, device=torch.cuda.current_device())
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.random.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
def check_streamingllm():
    """
    Drive the inference engine step by step with StreamingLLM enabled and check the
    running batch's block table and sequence length after each group of steps.
    """
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    llama_config = LlamaConfig(
        vocab_size=50000,
        hidden_size=512,
        intermediate_size=1536,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_hidden_layers=16,
    )
    model = LlamaForCausalLM(llama_config).cuda()
    model = model.eval()

    input_token_ids = data_gen(1, 4)

    output_len = 128

    inference_config = InferenceConfig(
        max_batch_size=1,
        max_output_len=output_len,
        dtype="fp32",
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=32,
    )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    inference_engine.add_request(prompts_token_ids=input_token_ids)
    assert inference_engine.request_handler._has_waiting()

    assert inference_config.start_token_size == inference_config.block_size

    request_handler = inference_engine.request_handler
    running_bb = request_handler.running_bb

    # Each entry: (number of engine steps to run, expected block table of sequence 0,
    # expected length of sequence 0 after those steps).
    schedule = (
        (12, [0, -1, -1, -1], 16),
        (16, [0, 1, -1, -1], 32),
        (16, [0, 1, 2, -1], 48),
        (16, [0, 2, 3, -1], 48),
        (1, [0, 2, 3, 1], 49),
        (15, [0, 3, 1, -1], 48),
    )
    for num_steps, expected_block_table, expected_seq_len in schedule:
        for _ in range(num_steps):
            inference_engine.step()
        assert running_bb.block_tables[0].tolist() == expected_block_table
        assert running_bb.seq_lengths[0].item() == expected_seq_len
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    """
    Launch colossalai for this rank on localhost, then call ``func_to_run(**kwargs)``.

    When ``ret`` is truthy, the call's result is stored at ``ret[rank]``; otherwise the
    result is discarded.
    """
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if not ret:
        func_to_run(**kwargs)
        return
    ret[rank] = func_to_run(**kwargs)
@rerun_if_address_is_in_use()
def test_engine():
    """Run ``check_streamingllm`` in one spawned process and return the value it stored in the shared list."""
    shared_results = Manager().list([-1])  # one slot, for the single spawned rank
    spawn(run_dist, 1, func_to_run=check_streamingllm, ret=shared_results)
    return shared_results[0]
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_engine()