You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_infer/test_cuda_graph.py

97 lines
2.9 KiB

import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def setup_seed(seed):
    """Reseed every random-number generator this test relies on.

    Seeds, in order: torch (CPU), torch CUDA (all devices), NumPy's global
    legacy RNG, and Python's ``random`` module, all with the same ``seed``.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
def check_inference_engine(use_cuda_graph=False, batch_size=32):
    """Build a small random Llama model, run sampled generation, and return the outputs.

    ``setup_seed(20)`` is called first, so two calls that differ only in
    ``use_cuda_graph`` start from the same RNG state.

    Args:
        use_cuda_graph: Forwarded (as a bool) to ``InferenceConfig(use_cuda_graph=...)``.
        batch_size: Number of random prompts to generate; also used as
            ``max_batch_size`` for the engine.

    Returns:
        Whatever ``InferenceEngine.generate`` returns for the batch.
    """
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
            )
        )
        .cuda()
        .half()
    )
    model = model.eval()

    # Random token-id prompts (ids in [0, 100)) of random length between 1 and
    # max(1024 // batch_size, 32) tokens. Per prompt, ``random.randint`` is
    # drawn before ``np.random.randint``, the same order as an explicit loop.
    prompts_token_ids = [
        np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
        for _ in range(batch_size)
    ]

    input_len = 1024
    output_len = 128
    do_sample = True
    top_p = 0.5
    top_k = 50

    # The engines being compared differ only in this flag, so build the
    # config once instead of duplicating it per branch.
    inference_config = InferenceConfig(
        max_batch_size=batch_size,
        max_input_len=input_len,
        max_output_len=output_len,
        use_cuda_graph=bool(use_cuda_graph),
        block_size=16,
    )
    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len

    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
    return outputs
def check_output_consistency(batch_size):
    """Assert that CUDA-graph and eager inference produce identical outputs.

    Runs ``check_inference_engine`` twice with the same ``batch_size`` (once
    with ``use_cuda_graph=True``, once with ``False``) and compares the results
    element by element.

    Raises:
        AssertionError: if the two runs return a different number of outputs,
            or if any pair of outputs differs.
    """
    # Materialize both results so they can be measured and indexed.
    cuda_graph_output = list(check_inference_engine(use_cuda_graph=True, batch_size=batch_size))
    naive_model_output = list(check_inference_engine(use_cuda_graph=False, batch_size=batch_size))

    # zip() stops at the shorter sequence, so without this check a run that
    # dropped outputs would pass silently.
    assert len(cuda_graph_output) == len(naive_model_output), (
        f"Output count mismatch: CUDA Graph produced {len(cuda_graph_output)}, "
        f"Origin produced {len(naive_model_output)}"
    )
    for idx, (s1, s2) in enumerate(zip(cuda_graph_output, naive_model_output)):
        assert s1 == s2, f"\nMismatch at index {idx}\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then check several batch sizes.

    The consistency check runs for batch sizes 32, 64 and 128, in that order.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    for bsz in (32, 64, 128):
        check_output_consistency(bsz)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
    """Spawn ``run_dist`` in a single process to compare CUDA-graph and eager outputs."""
    nprocs = 1
    spawn(run_dist, nprocs)


if __name__ == "__main__":
    # Allow running the check directly, without pytest.
    test_cuda_graph_infer()