
[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)

* Fix bugs in inference_engine

* fix bugs in engine.py

* rm CUDA_VISIBLE_DEVICES

* add request_ids in generate

* fix bug in engine.py

* add logger.debug for BatchBucket
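
Taken together, these fixes mean generate() now normalizes a bare string prompt into a list and threads explicit request IDs through to add_request(). A rough usage sketch, assuming the engine is constructed as InferenceEngine(model, tokenizer, InferenceConfig()) with default settings; the checkpoint name and prompt below are placeholders, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Placeholder checkpoint; any causal LM the engine supports would do.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

engine = InferenceEngine(model, tokenizer, InferenceConfig())

# A bare string is wrapped into a list internally, and the explicit request IDs
# are forwarded from generate() to add_request().
outputs = engine.generate(prompts="What is deep learning?", request_ids=[0])
print(outputs)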
yuehuayingxueluo authored 9 months ago, committed by GitHub
commit bc1da87366
5 changed files (35 changed lines):
  1. colossalai/inference/batch_bucket.py (3)
  2. colossalai/inference/core/engine.py (24)
  3. colossalai/inference/struct.py (2)
  4. examples/inference/run_benchmark.sh (2)
  5. tests/test_infer/test_batch_bucket.py (4)

colossalai/inference/batch_bucket.py (3)

@@ -447,3 +447,6 @@ class BatchBucket:
     def fd_inter_tensor(self) -> None:
         assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
         return self.fd_interm_tensor
+
+    def __repr__(self) -> str:
+        return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"

colossalai/inference/core/engine.py (24)

@@ -57,6 +57,7 @@ class InferenceEngine:
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.generation_config = inference_config.to_generation_config(self.model_config)

         model = model.eval()
+        model = model.cuda()
         model.to(self.dtype)
         if model_policy is None:
@@ -133,12 +134,13 @@ class InferenceEngine:
             )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
-        return shard_model.cuda()
+        return shard_model

     def generate(
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
@@ -148,6 +150,7 @@ class InferenceEngine:
         Args:
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             return_token_ids (bool): Whether to return output token ids. Defaults to False.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
@@ -157,7 +160,7 @@ class InferenceEngine:
         with torch.inference_mode():
             self.generation_config = generation_config
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

             output_seqs_list = []
             total_tokens_list = []
@@ -204,7 +207,7 @@ class InferenceEngine:

     def add_request(
         self,
-        requests_id: List[int] = None,
+        request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
@@ -212,7 +215,7 @@ class InferenceEngine:
         Add requests.

         Args:
-            requests_id (List[int], optional): The request ID. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
         """
@@ -223,6 +226,9 @@ class InferenceEngine:
         block_size = self.inference_config.block_size

+        if prompts is not None and not isinstance(prompts, list):
+            prompts = [prompts]
+
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
@@ -245,8 +251,14 @@ class InferenceEngine:
         prompts_num = len(prompts_token_ids)

         for i in range(prompts_num):
-            if requests_id:
-                request_id = requests_id[i]
+            if request_ids:
+                if not isinstance(request_ids, list):
+                    request_ids = [request_ids]
+                assert isinstance(
+                    request_ids[0], int
+                ), f"The request_id type must be int, but got {type(request_ids[0])}"
+                assert len(request_ids) == prompts_num
+                request_id = request_ids[i]
             else:
                 request_id = next(self.counter)
             if prompts == None:
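
The request_ids handling added in the last hunk accepts either a single int or a list and validates it against the number of prompts. A minimal standalone sketch of that normalization, using a hypothetical helper name (normalize_request_ids is not part of the engine; it just mirrors the added checks):

from typing import List, Optional, Union

def normalize_request_ids(request_ids: Union[int, List[int], None], prompts_num: int) -> Optional[List[int]]:
    # Mirrors the checks added in add_request: wrap a bare int, require int
    # elements, and require exactly one ID per prompt.
    if not request_ids:
        return None
    if not isinstance(request_ids, list):
        request_ids = [request_ids]
    assert isinstance(request_ids[0], int), f"The request_id type must be int, but got {type(request_ids[0])}"
    assert len(request_ids) == prompts_num
    return request_ids

# A single ID paired with a single prompt passes; a length mismatch would raise.
assert normalize_request_ids(7, prompts_num=1) == [7]
assert normalize_request_ids(None, prompts_num=2) is None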

colossalai/inference/struct.py (2)

@@ -157,7 +157,7 @@ class Sequence:
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
             f"sample_params={self.sample_params}, "
-            f"input_len={self.input_len}),"
+            f"input_len={self.input_len},"
             f"output_len={self.output_len})"
         )

examples/inference/run_benchmark.sh (2)

@@ -27,7 +27,7 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 for input_len in 128 512 1024; do
     for output_len in 128 256; do
         for bsz in 16 32 64; do
-            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --model_path "/home/caidi/llama_model/" | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
+            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
         done
     done
 done

tests/test_infer/test_batch_bucket.py (4)

@@ -5,8 +5,11 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize

+logger = get_dist_logger(__name__)
+

 @parameterize(
     "test_config",
@@ -83,6 +86,7 @@ def test_bucket(test_config):
         num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
     )
     block_tables = bb.add_seqs([seq1, seq2])
+    logger.debug(f"bb information: {bb}")
     assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
     assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"
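
The logger.debug call above is only informative because of the new BatchBucket.__repr__; a toy illustration of the same pattern with a stand-in class (not the real BatchBucket):

from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

class _ToyBucket:
    # Stand-in that mimics the shape of the new BatchBucket.__repr__.
    def __repr__(self) -> str:
        return "(sequences_dict={}, is_prompts=True)"

# With no __str__ defined, the f-string falls back to __repr__, so the
# bucket's state shows up directly in the debug log line.
logger.debug(f"bb information: {_ToyBucket()}")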
