mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to the opinions of review
* add Citation
* change _block_tables tolist

parent ee6fd38373
commit b45000f839
@@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr
 - [vLLM](https://github.com/vllm-project/vllm)
 - [flash-attention](https://github.com/Dao-AILab/flash-attention)
 - [HuggingFace](https://huggingface.co)
+- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm)
 If you wish to cite relevant research papars, you can find the reference below.

 ```bibtex
@@ -301,4 +302,12 @@ If you wish to cite relevant research papars, you can find the reference below.
   author={Dao, Tri},
   year={2023}
 }
+
+# StreamingLLM
+@article{xiao2023streamingllm,
+  title={Efficient Streaming Language Models with Attention Sinks},
+  author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
+  journal={arXiv},
+  year={2023}
+}
 ```
@@ -31,6 +31,9 @@ class BatchBucket:
         fd_interm_tensor=None,
         device=None,
         dtype=torch.float16,
+        enable_streamingllm: bool = False,
+        start_token_size: int = 4,
+        generated_token_size: int = 512,
     ):
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -45,12 +48,19 @@ class BatchBucket:
         self._use_spec_dec = False
         self._num_tokens_to_verify = None

+        self.enable_streamingllm = enable_streamingllm
+        self.start_token_size = start_token_size
+        self.generated_token_size = generated_token_size
+
         self._current_batch_size = 0
         self._sequences_dict = dict()
         self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
         self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
         self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
-        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
+        if enable_streamingllm:
+            max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
+        else:
+            max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
         self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
         self._block_tables_helper = torch.full_like(self._block_tables, -1)

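For intuition, with the values the new test in this PR ends up using (block_size=16, start_token_size rounded up to the block size, generated_token_size=32), the streaming branch above gives each sequence a 4-slot block table, which is why the test later asserts rows like `[0, -1, -1, -1]`. A small sketch, not part of the patch:

```python
# Illustrative only: block-table width under StreamingLLM with the test's effective values.
import torch

block_size, start_token_size, generated_token_size = 16, 16, 32
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1  # 4

block_tables = torch.full((1, max_blocks_per_seq), -1, dtype=torch.int32)
print(block_tables.tolist())  # [[-1, -1, -1, -1]]
```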
@@ -109,6 +119,33 @@ class BatchBucket:
             out.append(seq.input_token_id + seq.output_token_id)
         return out

+    def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
+        """
+        Update sequence_lengths and block_tables when it is necessary to swap out a block.
+        """
+
+        updated_block_ids = []
+
+        if self.current_batch_size > 0:
+            need_update = False
+            sequence_lengths_list = self._sequence_lengths.tolist()
+            block_tables_list = self._block_tables[: self._current_batch_size].tolist()
+            for batch_id in range(self.current_batch_size):
+                # We assume that the start token occupies the entire first block.
+                if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
+                    need_update = True
+                    sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
+                    block_id = block_tables_list[batch_id].pop(1)
+                    updated_block_ids.append(block_id)
+                    block_tables_list[batch_id].append(-1)
+            if need_update:
+                self._sequence_lengths = torch.tensor(
+                    sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
+                )
+                self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)
+
+        return updated_block_ids
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculatvie decoding.
         This will notify the adjust the lengths of inputs during modeling,
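The method fires once a sequence's cached length reaches start_token_size + generated_token_size + block_size - 1: it shrinks the recorded length back to start_token_size + generated_token_size - 1, pops the second block-table entry (the first block holds the attention sinks and is never evicted), and returns the evicted block ids so the cache manager can free them. A toy sketch on plain lists, using the effective test values (start_token_size forced to 16, generated_token_size=32, block_size=16), not part of the patch:

```python
# Illustrative sketch of the swap performed by streamingllm_update_batch.
start_token_size, generated_token_size, block_size = 16, 32, 16

seq_len = start_token_size + generated_token_size + block_size - 1  # 63: swap-out trigger
block_table = [0, 1, 2, 3]  # block 0 holds the "start tokens" (attention sinks)

if seq_len == start_token_size + generated_token_size + block_size - 1:
    seq_len = start_token_size + generated_token_size - 1  # 47
    evicted = block_table.pop(1)  # always the second block, never the sink block
    block_table.append(-1)        # leave a free slot for the next allocated block

print(evicted)      # 1 -> handed to the cache manager via streamingllm_free_block_tables
print(block_table)  # [0, 2, 3, -1]
```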
@@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
         top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
         top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
         temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
-        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
         no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
+        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
         n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
         glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
         micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
         max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
-        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
-        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
+        start_token_size(int): The size of the start tokens, when using StreamingLLM.
+        generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
     """

     # NOTE: arrange configs according to their importance and frequency of usage
@@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     forced_eos_token_id: int = None
+    ignore_eos: bool = False

     # speculative decoding configs
     max_n_spec_tokens: int = 5
@@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
     pp_size: int = 1
     micro_batch_size: int = 1
     micro_batch_buffer_size: int = None
-    high_precision: Optional[bool] = False

     # cuda kernel option
     use_cuda_kernel: bool = False
+    high_precision: Optional[bool] = False

     # cuda_graph
     use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
-    ignore_eos: bool = False
+
+    # StreamingLLM (sliding window attention with attention sinks)
+    enable_streamingllm: bool = False
+    start_token_size: int = 4
+    generated_token_size: int = 512

     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
@@ -260,6 +268,20 @@ class InferenceConfig(RPC_PARAM):
         if self.dtype == torch.float32:
             self.high_precision = False

+        # check StreamingLLM
+        assert (
+            self.start_token_size <= self.block_size
+        ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
+        assert (
+            self.generated_token_size % self.block_size == 0
+        ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
+        # Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
+        # based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
+        # we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
+        # we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
+        # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
+        self.start_token_size = self.block_size
+
         # check prompt template
         if self.prompt_template is None:
             return
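The validation above can be exercised in isolation. The `_StreamingCfg` class below is a hypothetical stand-in for the relevant InferenceConfig fields, not the real class; it only mirrors the two asserts and the rounding of the sink region to one full block:

```python
# Sketch of the StreamingLLM config checks (stand-in class, default values from the patch).
from dataclasses import dataclass


@dataclass
class _StreamingCfg:
    block_size: int = 16
    start_token_size: int = 4
    generated_token_size: int = 512

    def __post_init__(self):
        assert self.start_token_size <= self.block_size
        assert self.generated_token_size % self.block_size == 0
        # The first physical block is dedicated to the attention sinks,
        # so the sink region is rounded up to exactly one block.
        self.start_token_size = self.block_size


cfg = _StreamingCfg()
print(cfg.start_token_size)  # 16, not 4
# _StreamingCfg(generated_token_size=500)  # would fail: 500 is not a multiple of 16
```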
@@ -667,6 +667,11 @@ class InferenceEngine:
             elif max_length is not None:
                 max_new_tokens = max_length - len(prompts_token_ids[i])
+
+            if not self.inference_config.enable_streamingllm:
+                assert (
+                    self.inference_config.max_output_len >= max_new_tokens
+                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."

             sequence = Sequence(
                 request_id,
                 prompt,
@@ -754,6 +759,13 @@ class InferenceEngine:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
+
+        if self.inference_config.enable_streamingllm:
+            updated_block_ids = batch.streamingllm_update_batch(
+                self.inference_config.start_token_size, self.inference_config.generated_token_size
+            )
+            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
         next_tokens = search_tokens(
             self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
         )
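Per decoding step, the evicted block ids flow from the batch bucket through the request handler down to the cache manager. The stub classes below are stand-ins, not the real ColossalAI classes; they only sketch that delegation chain under the assumption that freed block ids can simply be marked reusable:

```python
# Self-contained sketch of the call chain added above (stand-in classes only).
class _CacheManagerStub:
    def __init__(self, num_blocks):
        self.free = [False] * num_blocks

    def streamingllm_free_block_tables(self, block_ids):
        for bid in block_ids:
            if bid >= 0:
                self.free[bid] = True  # block can be reallocated on a later step


class _RequestHandlerStub:
    def __init__(self, cache_manager):
        self.cache_manager = cache_manager

    def streamingllm_free_block_tables(self, block_ids):
        self.cache_manager.streamingllm_free_block_tables(block_ids)


class _BatchStub:
    def streamingllm_update_batch(self, start_token_size, generated_token_size):
        return [1]  # pretend block 1 was just swapped out


cache = _CacheManagerStub(num_blocks=8)
handler = _RequestHandlerStub(cache)
batch = _BatchStub()

# Mirrors the two added lines in InferenceEngine.step():
updated_block_ids = batch.streamingllm_update_batch(4, 512)
handler.streamingllm_free_block_tables(updated_block_ids)
print(cache.free[1])  # True
```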
@@ -157,6 +157,9 @@ class RequestHandler:
             fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
             device=device,
+            enable_streamingllm=inference_config.enable_streamingllm,
+            start_token_size=inference_config.start_token_size,
+            generated_token_size=inference_config.generated_token_size,
         )
         self.prefill_bb = BatchBucket(
             num_heads=model_config.num_attention_heads // inference_config.tp_size,
@@ -168,6 +171,9 @@ class RequestHandler:
             fd_interm_tensor=fd_inter_tensor,
             dtype=self.dtype,
             device=device,
+            enable_streamingllm=inference_config.enable_streamingllm,
+            start_token_size=inference_config.start_token_size,
+            generated_token_size=inference_config.generated_token_size,
         )

     def _init_cache(self, model_config):
@@ -350,6 +356,12 @@ class RequestHandler:

         return finished_seqs

+    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
+        """
+        Free the block that needs to be swapped out.
+        """
+        self.cache_manager.streamingllm_free_block_tables(updated_block_ids)
+

 class RPCRequestHandler(RequestHandler):
     """
@@ -78,10 +78,16 @@ class KVCacheManager:
         self.max_output_length = config.max_output_len
         # Cache block settings
         self.block_size = config.block_size

         # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
-        self.max_blocks_per_sequence = (
-            self.max_input_length + self.max_output_length + self.block_size - 1
-        ) // self.block_size
+        if config.enable_streamingllm:
+            self.max_blocks_per_sequence = (
+                config.start_token_size + config.generated_token_size + self.block_size - 1
+            ) // self.block_size + 1
+        else:
+            self.max_blocks_per_sequence = (
+                self.max_input_length + self.max_output_length + self.block_size - 1
+            ) // self.block_size
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
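For a sense of scale, the streaming branch above caps the pre-allocated KV pool by the attention-sink window instead of by the full input/output budget. The numbers below are illustrative only (not from the patch), assuming block_size=16, max_batch_size=8, beam_width=1, a 2048-token input/output budget, and start_token_size already rounded to the block size:

```python
# Back-of-the-envelope sketch of the cache pool size with and without StreamingLLM.
block_size, max_batch_size, beam_width = 16, 8, 1
max_input_length = max_output_length = 2048
start_token_size, generated_token_size = 16, 512

blocks_regular = (max_input_length + max_output_length + block_size - 1) // block_size           # 256
blocks_streaming = (start_token_size + generated_token_size + block_size - 1) // block_size + 1  # 34

print(blocks_regular * max_batch_size * beam_width)    # 2048 blocks to pre-allocate
print(blocks_streaming * max_batch_size * beam_width)  # 272 blocks
```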
@@ -446,6 +452,20 @@ class KVCacheManager:
         self._available_blocks = self.num_blocks
         self._block_states[:] = 1

+    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
+        """
+        Free the block that needs to be swapped out.
+        """
+        for global_block_id in updated_block_ids:
+            if global_block_id < 0:
+                return
+            block: CacheBlock = self._cache_blocks[global_block_id]
+            block.remove_ref()
+            if not block.has_ref():
+                block.allocated_size = 0
+                self._available_blocks += 1
+                self._block_states[global_block_id] = 1
+
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
         return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
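On the cache-manager side, an evicted block is only returned to the pool once its reference count drops to zero. The toy class below is a stand-in, not the real CacheBlock, and only sketches that ref-count-gated reclaim:

```python
# Toy sketch of the reference-counted free performed above (stand-in block class).
class _Block:
    def __init__(self):
        self.ref_count = 1
        self.allocated_size = 16

    def remove_ref(self):
        self.ref_count -= 1

    def has_ref(self):
        return self.ref_count > 0


available_blocks, block_states = 0, {3: 0}  # block 3 currently in use (state 0)
blocks = {3: _Block()}

for global_block_id in [3]:                 # ids returned by streamingllm_update_batch
    block = blocks[global_block_id]
    block.remove_ref()
    if not block.has_ref():                 # last reference gone -> reclaim
        block.allocated_size = 0
        available_blocks += 1
        block_states[global_block_id] = 1   # 1 marks the block as free

print(available_blocks, block_states)  # 1 {3: 1}
```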
@@ -533,10 +553,16 @@ class RPCKVCacheManager(KVCacheManager):
         self.max_output_length = config.max_output_len
         # Cache block settings
         self.block_size = config.block_size

         # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
-        self.max_blocks_per_sequence = (
-            self.max_input_length + self.max_output_length + self.block_size - 1
-        ) // self.block_size
+        if config.enable_streamingllm:
+            self.max_blocks_per_sequence = (
+                config.start_token_size + config.generated_token_size + self.block_size - 1
+            ) // self.block_size + 1
+        else:
+            self.max_blocks_per_sequence = (
+                self.max_input_length + self.max_output_length + self.block_size - 1
+            ) // self.block_size
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Logical cache blocks allocation
@@ -48,6 +48,9 @@ def infer(args):
         block_size=16,
         tp_size=args.tp_size,
         use_cuda_kernel=args.use_cuda_kernel,
+        enable_streamingllm=args.enable_streamingllm,
+        start_token_size=args.start_token_size,
+        generated_token_size=args.generated_token_size,
     )
     coordinator.print_on_master(f"Initializing Inference Engine...")
     engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
@@ -63,6 +66,8 @@ def infer(args):
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
+        no_repeat_ngram_size=args.no_repeat_ngram_size,
+        repetition_penalty=args.repetition_penalty,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
@@ -107,6 +112,25 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
     parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
+    parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
+    parser.add_argument(
+        "--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM,"
+    )
+    parser.add_argument(
+        "--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM"
+    )
+    parser.add_argument(
+        "--no_repeat_ngram_size",
+        type=int,
+        default=0,
+        help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.",
+    )
+    parser.add_argument(
+        "--repetition_penalty",
+        type=float,
+        default=1.0,
+        help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.",
+    )
     args = parser.parse_args()

     infer(args)
@@ -0,0 +1,122 @@
+import random
+
+import numpy as np
+import torch
+from torch.multiprocessing import Manager
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def data_gen(batch_size: int = 4, seq_len: int = 512):
+    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=torch.cuda.current_device())
+    return input_ids
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def check_streamingllm():
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = LlamaForCausalLM(
+        LlamaConfig(
+            vocab_size=50000,
+            hidden_size=512,
+            intermediate_size=1536,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            num_hidden_layers=16,
+        )
+    ).cuda()
+    model = model.eval()
+
+    input_token_ids = data_gen(1, 4)
+
+    output_len = 128
+
+    inference_config = InferenceConfig(
+        max_batch_size=1,
+        max_output_len=output_len,
+        dtype="fp32",
+        use_cuda_kernel=True,
+        enable_streamingllm=True,
+        start_token_size=4,
+        generated_token_size=32,
+    )
+
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == output_len
+    inference_engine.add_request(prompts_token_ids=input_token_ids)
+    assert inference_engine.request_handler._has_waiting()
+
+    assert inference_config.start_token_size == inference_config.block_size
+
+    request_handler = inference_engine.request_handler
+    running_bb = request_handler.running_bb
+
+    for _ in range(12):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1]
+    assert running_bb.seq_lengths[0].item() == 16
+
+    for _ in range(16):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1]
+    assert running_bb.seq_lengths[0].item() == 32
+
+    for _ in range(16):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1]
+    assert running_bb.seq_lengths[0].item() == 48
+
+    for _ in range(16):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, 2, 3, -1]
+    assert running_bb.seq_lengths[0].item() == 48
+
+    for _ in range(1):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, 2, 3, 1]
+    assert running_bb.seq_lengths[0].item() == 49
+
+    for _ in range(15):
+        inference_engine.step()
+
+    assert running_bb.block_tables[0].tolist() == [0, 3, 1, -1]
+    assert running_bb.seq_lengths[0].item() == 48
+
+
+def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    if ret:
+        ret[rank] = func_to_run(**kwargs)
+    else:
+        func_to_run(**kwargs)
+
+
+@rerun_if_address_is_in_use()
+def test_engine():
+    manager = Manager()
+    result_list = manager.list([-1] * 1)  # Create a shared list
+
+    spawn(run_dist, 1, func_to_run=check_streamingllm, ret=result_list)
+    return result_list[0]
+
+
+if __name__ == "__main__":
+    test_engine()