[inference] Fix running time of test_continuous_batching (#5750)

pull/5668/head
Yuanheng Zhao 2024-05-24 19:34:15 +08:00 committed by GitHub
parent 5f8c0a0ac3
commit b96c6390f4
1 changed file with 24 additions and 56 deletions


@@ -3,10 +3,10 @@ import random
 import numpy as np
 import pytest
 import torch
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
 
 import colossalai
-from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -28,69 +28,37 @@ def generate_inputs(num_sequences, min_length, max_length):
     return sequences
 
 
-@parameterize(
-    "test_config",
-    [
-        {
-            "max_batch_size": 8,
-            "max_output_len": 512,
-            "max_input_len": 64,
-            "do_sample": False,
-        }
-    ],
-)
-def check_inference_engine(test_config, use_engine=False, prompt_template=None):
+@parameterize("n_multiple", [10])
+@parameterize("max_batch_size", [8])
+@parameterize("max_input_len", [128])
+@parameterize("max_output_len", [128])
+def check_inference_engine(n_multiple, max_batch_size, max_input_len, max_output_len):
     setup_seed(20)
-    max_batch_size = test_config["max_batch_size"]
-    max_input_len = test_config["max_input_len"]
-    max_output_len = test_config["max_output_len"]
-    do_sample = test_config["do_sample"]
-    top_p = 0.5
-    top_k = 50
-    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()
     model = model.eval()
 
-    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
+    inputs_token_ids = generate_inputs(
+        n_multiple * max_batch_size, min_length=max_input_len // 2, max_length=max_input_len
+    )
 
-    if use_engine:
-        inference_config = InferenceConfig(
-            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
-        )
-        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
-        assert inference_engine.generation_config.max_new_tokens == max_output_len
-        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
-        assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
-        outputs = inference_engine.generate(generation_config=generation_config)
-    else:
-        if prompt_template:
-            # apply prompt template
-            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        inputs = inputs.cuda()
-        generation_config = GenerationConfig(
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            max_new_tokens=max_output_len,
-        )
-        outputs = model.generate(inputs, generation_config=generation_config)
-        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert len(outputs) == 10 * max_batch_size
+    inference_config = InferenceConfig(
+        max_batch_size=max_batch_size, max_input_len=max_input_len, max_output_len=max_output_len
+    )
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == max_output_len
 
+    inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+    assert inference_engine.request_handler._has_waiting()
 
-@parameterize("prompt_template", [None, "llama"])
-def check_continuous_batching(prompt_template):
-    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+    outputs = inference_engine.generate()
+
+    assert not inference_engine.request_handler._has_waiting()
+    assert len(outputs) == n_multiple * max_batch_size
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
-    check_continuous_batching()
+    check_inference_engine()
 
 
 @pytest.mark.dist
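
Note: the hunk ends at the untouched @pytest.mark.dist context line, so the test entry point itself is not part of this diff. As a rough sketch only (assuming the usual ColossalAI spawn-based pattern suggested by the spawn and rerun_if_address_is_in_use imports above, not the verbatim remainder of the file), a driver for the updated check would look like:

# Hypothetical driver sketch; the real file's entry point is outside this diff.
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_continuous_batching():
    # launch a single worker that runs check_inference_engine via run_dist
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_continuous_batching()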