mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Add Streaming LLM (#5745)

* Add Streaming LLM
* Add some parameters to llama_generation.py
* Verify StreamingLLM config
* Add test_streamingllm.py
* Modify according to review comments
* Add citation
* Change _block_tables to list
parent ee6fd38373
commit b45000f839
@@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr
- [vLLM](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [HuggingFace](https://huggingface.co)
- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm)
If you wish to cite relevant research papers, you can find the reference below.

```bibtex
@@ -301,4 +302,12 @@ If you wish to cite relevant research papers, you can find the reference below.
  author={Dao, Tri},
  year={2023}
}

# StreamingLLM
@article{xiao2023streamingllm,
  title={Efficient Streaming Language Models with Attention Sinks},
  author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
  journal={arXiv},
  year={2023}
}
```
@@ -31,6 +31,9 @@ class BatchBucket:
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
        enable_streamingllm: bool = False,
        start_token_size: int = 4,
        generated_token_size: int = 512,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim

@@ -45,12 +48,19 @@ class BatchBucket:
        self._use_spec_dec = False
        self._num_tokens_to_verify = None

        self.enable_streamingllm = enable_streamingllm
        self.start_token_size = start_token_size
        self.generated_token_size = generated_token_size

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        if enable_streamingllm:
            max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
        else:
            max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)
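For intuition, here is a small sketch (not part of the diff) of the block-table sizing above. The concrete numbers are assumptions: `block_size=16` and `generated_token_size=512` are the `InferenceConfig` defaults, `start_token_size` is shown after `InferenceConfig.__post_init__` has promoted it to the block size, and `max_length=2048` is a hypothetical input+output budget.

```python
# Illustrative only: reproduces the max_blocks_per_seq arithmetic from BatchBucket.__init__ above.
block_size = 16
start_token_size = 16        # 4 by default, promoted to block_size in InferenceConfig.__post_init__
generated_token_size = 512
max_length = 2048            # hypothetical max_input_len + max_output_len

# Without StreamingLLM the block table must cover the full max_length:
print((max_length + block_size - 1) // block_size)  # 128 blocks per sequence

# With StreamingLLM it only covers the sink block plus the sliding window,
# plus one extra block of headroom before a swap is triggered:
print((start_token_size + generated_token_size + block_size - 1) // block_size + 1)  # 34 blocks per sequence
```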
@@ -109,6 +119,33 @@ class BatchBucket:
            out.append(seq.input_token_id + seq.output_token_id)
        return out

    def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
        """
        Update sequence_lengths and block_tables when it is necessary to swap out a block.
        """

        updated_block_ids = []

        if self.current_batch_size > 0:
            need_update = False
            sequence_lengths_list = self._sequence_lengths.tolist()
            block_tables_list = self._block_tables[: self._current_batch_size].tolist()
            for batch_id in range(self.current_batch_size):
                # We assume that the start token occupies the entire first block.
                if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
                    need_update = True
                    sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
                    block_id = block_tables_list[batch_id].pop(1)
                    updated_block_ids.append(block_id)
                    block_tables_list[batch_id].append(-1)
            if need_update:
                self._sequence_lengths = torch.tensor(
                    sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
                )
                self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)

        return updated_block_ids

    def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
        """Set batch bucket to use speculative decoding.
        This will notify the adjust the lengths of inputs during modeling,
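To make the swap-out rule concrete, here is a toy sketch (illustrative, not from the diff) that replays the list manipulation of `streamingllm_update_batch` for a single sequence, using `block_size=16`, `start_token_size=16` (after promotion) and `generated_token_size=32` as in the unit test at the end of this commit:

```python
# Toy replay of the eviction step in streamingllm_update_batch above.
block_size, start_token_size, generated_token_size = 16, 16, 32

seq_len = start_token_size + generated_token_size + block_size - 1  # 63: the swap threshold
block_table = [0, 1, 2, 3]  # logical block ids of one sequence

if seq_len == start_token_size + generated_token_size + block_size - 1:
    seq_len = start_token_size + generated_token_size - 1  # 47: one block's worth of tokens dropped
    swapped_out = block_table.pop(1)                        # the second block is always the one evicted
    block_table.append(-1)                                  # keep the table width fixed, free slot at the tail

print(seq_len, block_table, swapped_out)  # 47 [0, 2, 3, -1] 1
```

Block 0, which holds the attention-sink (start) tokens, is never evicted; the freed block id is returned so the cache manager can hand it out again.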
@@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
        top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
        top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
        temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
        no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
        n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
        glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
        block_size (int): The number of blocks in a logical block, defaults to 16.

@@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
        use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
        max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
        enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
        start_token_size(int): The size of the start tokens, when using StreamingLLM.
        generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
    """

    # NOTE: arrange configs according to their importance and frequency of usage
@@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
    no_repeat_ngram_size: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    forced_eos_token_id: int = None
    ignore_eos: bool = False

    # speculative decoding configs
    max_n_spec_tokens: int = 5

@@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
    pp_size: int = 1
    micro_batch_size: int = 1
    micro_batch_buffer_size: int = None
    high_precision: Optional[bool] = False

    # cuda kernel option
    use_cuda_kernel: bool = False
    high_precision: Optional[bool] = False

    # cuda_graph
    use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
    max_context_len_to_capture: int = 512
    ignore_eos: bool = False

    # StreamingLLM (sliding window attention with attention sinks)
    enable_streamingllm: bool = False
    start_token_size: int = 4
    generated_token_size: int = 512

    def __post_init__(self):
        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
@@ -260,6 +268,20 @@ class InferenceConfig(RPC_PARAM):
        if self.dtype == torch.float32:
            self.high_precision = False

        # check StreamingLLM
        assert (
            self.start_token_size <= self.block_size
        ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
        assert (
            self.generated_token_size % self.block_size == 0
        ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
        # Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
        # based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
        # we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
        # we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
        # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
        self.start_token_size = self.block_size

        # check prompt template
        if self.prompt_template is None:
            return
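As a usage sketch (mirroring the unit test at the end of this diff rather than documenting new API), the checks above allow a StreamingLLM-enabled config to be built as follows, with `start_token_size` silently promoted to the block size:

```python
from colossalai.inference.config import InferenceConfig

# start_token_size must not exceed block_size, and generated_token_size must be a
# multiple of block_size, otherwise the assertions in __post_init__ above fire.
config = InferenceConfig(
    max_batch_size=1,
    max_output_len=128,
    enable_streamingllm=True,
    start_token_size=4,       # promoted to block_size (16) by __post_init__
    generated_token_size=32,  # multiple of block_size
)
assert config.start_token_size == config.block_size
```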
@@ -667,6 +667,11 @@ class InferenceEngine:
            elif max_length is not None:
                max_new_tokens = max_length - len(prompts_token_ids[i])

            if not self.inference_config.enable_streamingllm:
                assert (
                    self.inference_config.max_output_len >= max_new_tokens
                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."

            sequence = Sequence(
                request_id,
                prompt,

@@ -754,6 +759,13 @@ class InferenceEngine:
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        if self.inference_config.enable_streamingllm:
            updated_block_ids = batch.streamingllm_update_batch(
                self.inference_config.start_token_size, self.inference_config.generated_token_size
            )
            self.request_handler.streamingllm_free_block_tables(updated_block_ids)

        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )
@@ -157,6 +157,9 @@ class RequestHandler:
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            device=device,
            enable_streamingllm=inference_config.enable_streamingllm,
            start_token_size=inference_config.start_token_size,
            generated_token_size=inference_config.generated_token_size,
        )
        self.prefill_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,

@@ -168,6 +171,9 @@ class RequestHandler:
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            device=device,
            enable_streamingllm=inference_config.enable_streamingllm,
            start_token_size=inference_config.start_token_size,
            generated_token_size=inference_config.generated_token_size,
        )

    def _init_cache(self, model_config):

@@ -350,6 +356,12 @@ class RequestHandler:

        return finished_seqs

    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
        """
        Free the block that needs to be swapped out.
        """
        self.cache_manager.streamingllm_free_block_tables(updated_block_ids)


class RPCRequestHandler(RequestHandler):
    """
@@ -78,10 +78,16 @@ class KVCacheManager:
        self.max_output_length = config.max_output_len
        # Cache block settings
        self.block_size = config.block_size

        # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
        self.max_blocks_per_sequence = (
            self.max_input_length + self.max_output_length + self.block_size - 1
        ) // self.block_size
        if config.enable_streamingllm:
            self.max_blocks_per_sequence = (
                config.start_token_size + config.generated_token_size + self.block_size - 1
            ) // self.block_size + 1
        else:
            self.max_blocks_per_sequence = (
                self.max_input_length + self.max_output_length + self.block_size - 1
            ) // self.block_size
        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

        # Physical cache allocation
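For a rough sense of the cache budget this changes, a small sketch (illustrative only; the batch and length limits below are hypothetical) of the `num_blocks` computation above:

```python
# Reproduces the num_blocks arithmetic from KVCacheManager.__init__ above with made-up limits.
block_size, max_batch_size, beam_width = 16, 8, 1       # hypothetical config values
max_input_length, max_output_length = 1024, 1024        # hypothetical length limits
start_token_size, generated_token_size = 16, 512        # StreamingLLM window (start promoted to block_size)

dense_blocks_per_seq = (max_input_length + max_output_length + block_size - 1) // block_size             # 128
streaming_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1  # 34

print(dense_blocks_per_seq * max_batch_size * beam_width)      # 1024 physical blocks
print(streaming_blocks_per_seq * max_batch_size * beam_width)  # 272 physical blocks
```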
@@ -446,6 +452,20 @@ class KVCacheManager:
        self._available_blocks = self.num_blocks
        self._block_states[:] = 1

    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
        """
        Free the block that needs to be swapped out.
        """
        for global_block_id in updated_block_ids:
            if global_block_id < 0:
                return
            block: CacheBlock = self._cache_blocks[global_block_id]
            block.remove_ref()
            if not block.has_ref():
                block.allocated_size = 0
                self._available_blocks += 1
                self._block_states[global_block_id] = 1

    def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
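The freeing above is reference-counted: a block only goes back to the pool once no sequence still points at it. The stand-in class below is invented for illustration (it is not the real `CacheBlock` API), but mirrors the `remove_ref`/`has_ref` behaviour the method relies on:

```python
class ToyBlock:
    """Illustrative stand-in for a ref-counted cache block."""

    def __init__(self):
        self.ref_count = 0
        self.allocated_size = 0

    def add_ref(self):
        self.ref_count += 1

    def remove_ref(self):
        self.ref_count -= 1

    def has_ref(self):
        return self.ref_count > 0


block = ToyBlock()
block.add_ref()
block.add_ref()            # e.g. referenced by two sequences

block.remove_ref()
print(block.has_ref())     # True  -> still in use, not freed
block.remove_ref()
print(block.has_ref())     # False -> now the manager would mark the slot available again
```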
@@ -533,10 +553,16 @@ class RPCKVCacheManager(KVCacheManager):
        self.max_output_length = config.max_output_len
        # Cache block settings
        self.block_size = config.block_size

        # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
        self.max_blocks_per_sequence = (
            self.max_input_length + self.max_output_length + self.block_size - 1
        ) // self.block_size
        if config.enable_streamingllm:
            self.max_blocks_per_sequence = (
                config.start_token_size + config.generated_token_size + self.block_size - 1
            ) // self.block_size + 1
        else:
            self.max_blocks_per_sequence = (
                self.max_input_length + self.max_output_length + self.block_size - 1
            ) // self.block_size
        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

        # Logical cache blocks allocation
@@ -48,6 +48,9 @@ def infer(args):
        block_size=16,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        enable_streamingllm=args.enable_streamingllm,
        start_token_size=args.start_token_size,
        generated_token_size=args.generated_token_size,
    )
    coordinator.print_on_master(f"Initializing Inference Engine...")
    engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)

@@ -63,6 +66,8 @@ def infer(args):
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        repetition_penalty=args.repetition_penalty,
    )
    coordinator.print_on_master(f"Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=generation_config)

@@ -107,6 +112,25 @@ if __name__ == "__main__":
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
    parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
    parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
    parser.add_argument(
        "--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM,"
    )
    parser.add_argument(
        "--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM"
    )
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        default=0,
        help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.",
    )
    args = parser.parse_args()

    infer(args)
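Putting the pieces together, here is a condensed end-to-end sketch in the spirit of the example script above. The tiny random model, the prompt, the use of `transformers.GenerationConfig`, and the single-process `colossalai.launch` call are assumptions for illustration; the StreamingLLM arguments are the ones added by this commit:

```python
import colossalai
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# The engine expects an initialized colossalai context (the unit test below does the same).
colossalai.launch(rank=0, world_size=1, port=29500, host="localhost")

# A small random Llama stands in for a real checkpoint.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
    LlamaConfig(hidden_size=512, num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=4)
).cuda().eval()

inference_config = InferenceConfig(
    max_batch_size=1,
    max_output_len=128,
    dtype="fp32",
    enable_streamingllm=True,   # added by this commit
    start_token_size=4,
    generated_token_size=512,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

generation_config = GenerationConfig(
    max_new_tokens=128,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    no_repeat_ngram_size=0,
    repetition_penalty=1.0,
)
out = engine.generate(prompts=["Introduce some landmarks in Beijing"], generation_config=generation_config)
print(out)
```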
@@ -0,0 +1,122 @@
import random

import numpy as np
import torch
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn


def data_gen(batch_size: int = 4, seq_len: int = 512):
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=torch.cuda.current_device())
    return input_ids


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_streamingllm():
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000,
            hidden_size=512,
            intermediate_size=1536,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=16,
        )
    ).cuda()
    model = model.eval()

    input_token_ids = data_gen(1, 4)

    output_len = 128

    inference_config = InferenceConfig(
        max_batch_size=1,
        max_output_len=output_len,
        dtype="fp32",
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=32,
    )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    inference_engine.add_request(prompts_token_ids=input_token_ids)
    assert inference_engine.request_handler._has_waiting()

    assert inference_config.start_token_size == inference_config.block_size

    request_handler = inference_engine.request_handler
    running_bb = request_handler.running_bb

    for _ in range(12):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 16

    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 32

    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1]
    assert running_bb.seq_lengths[0].item() == 48

    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 2, 3, -1]
    assert running_bb.seq_lengths[0].item() == 48

    for _ in range(1):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 2, 3, 1]
    assert running_bb.seq_lengths[0].item() == 49

    for _ in range(15):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 3, 1, -1]
    assert running_bb.seq_lengths[0].item() == 48


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if ret:
        ret[rank] = func_to_run(**kwargs)
    else:
        func_to_run(**kwargs)


@rerun_if_address_is_in_use()
def test_engine():
    manager = Manager()
    result_list = manager.list([-1] * 1)  # Create a shared list

    spawn(run_dist, 1, func_to_run=check_streamingllm, ret=result_list)
    return result_list[0]


if __name__ == "__main__":
    test_engine()
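As a sanity check on the numbers asserted above (a short derivation rather than extra test code): with `start_token_size` promoted to 16 and `generated_token_size=32`, the swap threshold and the post-swap length work out as follows:

```python
block_size = 16
start_token_size = 16        # 4 in the config, promoted to block_size by __post_init__
generated_token_size = 32

threshold = start_token_size + generated_token_size + block_size - 1  # 63: length that triggers a swap
after_swap = start_token_size + generated_token_size - 1              # 47: length right after a swap

# The test's prompt has 4 tokens and every engine.step() appends one, so the length climbs toward 63;
# each time it hits the threshold, the block at table index 1 is evicted and the length falls back to 47,
# which is why the asserted seq_lengths never exceed 49.
print(threshold, after_swap)  # 63 47
```

This also matches the asserted block tables: index 0 (the attention-sink block) never moves, the block at table index 1 is always the one evicted, and a freed physical block (id 1 after the first swap) is handed out again as the newest block, which is why [0, 2, 3, 1] appears one step after [0, 2, 3, -1].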