Browse Source

[Inference]Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to the opinions of review

* add Citation

* change _block_tables tolist
pull/5782/head
yuehuayingxueluo 6 months ago committed by GitHub
parent
commit
b45000f839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 9
      colossalai/inference/README.md
  2. 37
      colossalai/inference/batch_bucket.py
  3. 32
      colossalai/inference/config.py
  4. 12
      colossalai/inference/core/engine.py
  5. 12
      colossalai/inference/core/request_handler.py
  6. 26
      colossalai/inference/kv_cache/kvcache_manager.py
  7. 24
      examples/inference/llama/llama_generation.py
  8. 122
      tests/test_infer/test_streamingllm.py

9
colossalai/inference/README.md

@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr
- [vLLM](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [HuggingFace](https://huggingface.co)
- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm)
If you wish to cite relevant research papars, you can find the reference below.
```bibtex
@ -301,4 +302,12 @@ If you wish to cite relevant research papars, you can find the reference below.
author={Dao, Tri},
year={2023}
}
# StreamingLLM
@article{xiao2023streamingllm,
title={Efficient Streaming Language Models with Attention Sinks},
author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
journal={arXiv},
year={2023}
}
```

37
colossalai/inference/batch_bucket.py

@ -31,6 +31,9 @@ class BatchBucket:
fd_interm_tensor=None,
device=None,
dtype=torch.float16,
enable_streamingllm: bool = False,
start_token_size: int = 4,
generated_token_size: int = 512,
):
self.num_heads = num_heads
self.head_dim = head_dim
@ -45,11 +48,18 @@ class BatchBucket:
self._use_spec_dec = False
self._num_tokens_to_verify = None
self.enable_streamingllm = enable_streamingllm
self.start_token_size = start_token_size
self.generated_token_size = generated_token_size
self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
if enable_streamingllm:
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
else:
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
self._block_tables_helper = torch.full_like(self._block_tables, -1)
@ -109,6 +119,33 @@ class BatchBucket:
out.append(seq.input_token_id + seq.output_token_id)
return out
def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
"""
Update sequence_lengths and block_tables when it is necessary to swap out a block.
"""
updated_block_ids = []
if self.current_batch_size > 0:
need_update = False
sequence_lengths_list = self._sequence_lengths.tolist()
block_tables_list = self._block_tables[: self._current_batch_size].tolist()
for batch_id in range(self.current_batch_size):
# We assume that the start token occupies the entire first block.
if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
need_update = True
sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
block_id = block_tables_list[batch_id].pop(1)
updated_block_ids.append(block_id)
block_tables_list[batch_id].append(-1)
if need_update:
self._sequence_lengths = torch.tensor(
sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
)
self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)
return updated_block_ids
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,

32
colossalai/inference/config.py

@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
forced_eos_token_id: int = None
ignore_eos: bool = False
# speculative decoding configs
max_n_spec_tokens: int = 5
@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option
use_cuda_kernel: bool = False
high_precision: Optional[bool] = False
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
ignore_eos: bool = False
# StreamingLLM (sliding window attention with attention sinks)
enable_streamingllm: bool = False
start_token_size: int = 4
generated_token_size: int = 512
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
@ -260,6 +268,20 @@ class InferenceConfig(RPC_PARAM):
if self.dtype == torch.float32:
self.high_precision = False
# check StreamingLLM
assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
# based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
# we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
# we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
# check prompt template
if self.prompt_template is None:
return

12
colossalai/inference/core/engine.py

@ -667,6 +667,11 @@ class InferenceEngine:
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
@ -754,6 +759,13 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)

12
colossalai/inference/core/request_handler.py

@ -157,6 +157,9 @@ class RequestHandler:
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
@ -168,6 +171,9 @@ class RequestHandler:
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
)
def _init_cache(self, model_config):
@ -350,6 +356,12 @@ class RequestHandler:
return finished_seqs
def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
"""
Free the block that needs to be swapped out.
"""
self.cache_manager.streamingllm_free_block_tables(updated_block_ids)
class RPCRequestHandler(RequestHandler):
"""

26
colossalai/inference/kv_cache/kvcache_manager.py vendored

@ -78,7 +78,13 @@ class KVCacheManager:
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size
# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
if config.enable_streamingllm:
self.max_blocks_per_sequence = (
config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
@ -446,6 +452,20 @@ class KVCacheManager:
self._available_blocks = self.num_blocks
self._block_states[:] = 1
def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
"""
Free the block that needs to be swapped out.
"""
for global_block_id in updated_block_ids:
if global_block_id < 0:
return
block: CacheBlock = self._cache_blocks[global_block_id]
block.remove_ref()
if not block.has_ref():
block.allocated_size = 0
self._available_blocks += 1
self._block_states[global_block_id] = 1
def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
@ -533,7 +553,13 @@ class RPCKVCacheManager(KVCacheManager):
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size
# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
if config.enable_streamingllm:
self.max_blocks_per_sequence = (
config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size

24
examples/inference/llama/llama_generation.py

@ -48,6 +48,9 @@ def infer(args):
block_size=16,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
enable_streamingllm=args.enable_streamingllm,
start_token_size=args.start_token_size,
generated_token_size=args.generated_token_size,
)
coordinator.print_on_master(f"Initializing Inference Engine...")
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
@ -63,6 +66,8 @@ def infer(args):
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
no_repeat_ngram_size=args.no_repeat_ngram_size,
repetition_penalty=args.repetition_penalty,
)
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
@ -107,6 +112,25 @@ if __name__ == "__main__":
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
parser.add_argument(
"--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM,"
)
parser.add_argument(
"--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM"
)
parser.add_argument(
"--no_repeat_ngram_size",
type=int,
default=0,
help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.0,
help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.",
)
args = parser.parse_args()
infer(args)

122
tests/test_infer/test_streamingllm.py

@ -0,0 +1,122 @@
import random
import numpy as np
import torch
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=torch.cuda.current_device())
return input_ids
def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_streamingllm():
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000,
hidden_size=512,
intermediate_size=1536,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=16,
)
).cuda()
model = model.eval()
input_token_ids = data_gen(1, 4)
output_len = 128
inference_config = InferenceConfig(
max_batch_size=1,
max_output_len=output_len,
dtype="fp32",
use_cuda_kernel=True,
enable_streamingllm=True,
start_token_size=4,
generated_token_size=32,
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts_token_ids=input_token_ids)
assert inference_engine.request_handler._has_waiting()
assert inference_config.start_token_size == inference_config.block_size
request_handler = inference_engine.request_handler
running_bb = request_handler.running_bb
for _ in range(12):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1]
assert running_bb.seq_lengths[0].item() == 16
for _ in range(16):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1]
assert running_bb.seq_lengths[0].item() == 32
for _ in range(16):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1]
assert running_bb.seq_lengths[0].item() == 48
for _ in range(16):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, 2, 3, -1]
assert running_bb.seq_lengths[0].item() == 48
for _ in range(1):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, 2, 3, 1]
assert running_bb.seq_lengths[0].item() == 49
for _ in range(15):
inference_engine.step()
assert running_bb.block_tables[0].tolist() == [0, 3, 1, -1]
assert running_bb.seq_lengths[0].item() == 48
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)
@rerun_if_address_is_in_use()
def test_engine():
manager = Manager()
result_list = manager.list([-1] * 1) # Create a shared list
spawn(run_dist, 1, func_to_run=check_streamingllm, ret=result_list)
return result_list[0]
if __name__ == "__main__":
test_engine()
Loading…
Cancel
Save