ColossalAI/tests/test_infer/test_inference_engine.py

import random

import numpy as np
import pytest
import torch
import transformers
from transformers import AutoTokenizer, GenerationConfig

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn

def setup_seed(seed):
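    """Seed the torch, CUDA, numpy, and Python RNGs so both generation paths are deterministic."""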
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_inference_engine(test_cai=False):
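    """Generate from a small random-weight Llama model and return the decoded outputs.

    With test_cai=True the ColossalAI InferenceEngine is used; otherwise the plain
    HuggingFace `model.generate` path is taken, so the two can be compared.
    """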
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
        )
    ).cuda()
    model = model.eval()

    # Chinese prompts: "Introduce today's Beijing, e.g. the Forbidden City, Tiananmen,
    # the Great Wall, or some other attractions," and "Introduce Wuhan,"
    inputs = [
        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
        "介绍一下武汉,",
    ]

    # Sampling configuration shared by both generation paths
    output_len = 38
    do_sample = True
    top_p = 0.5
    top_k = 50

    if test_cai:
        inference_config = InferenceConfig(max_output_len=output_len)
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
        outputs = inference_engine.generate(generation_config)
    else:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return outputs


def run_dist(rank, world_size, port):
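    """Launch a single-process ColossalAI context and check the engine against HuggingFace."""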
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    cai_outputs = check_inference_engine(True)
    transformer_outputs = check_inference_engine(False)

    # With identical seeds and sampling settings, the two paths should
    # produce token-for-token identical text.
    for s1, s2 in zip(cai_outputs, transformer_outputs):
        assert s1 == s2


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
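    """Spawn one worker and compare InferenceEngine output against `model.generate`."""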
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_inference_engine()