import random

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn


def setup_seed(seed):
    """Seed the torch (CPU and all CUDA devices), NumPy and ``random`` RNGs.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_inference_engine(use_cuda_graph=False, batch_size=32):
    """Run one generation pass through ``InferenceEngine`` and return its output.

    A small Llama model is built from a ``LlamaConfig`` (no pretrained weights
    are loaded), moved to CUDA in half precision and put in eval mode. The RNGs
    are re-seeded at the start of every call so that two calls with the same
    ``batch_size`` draw the same random prompts.

    Args:
        use_cuda_graph: Forwarded to ``InferenceConfig.use_cuda_graph``.
        batch_size: Number of random prompts to generate; also used as
            ``max_batch_size`` of the engine.

    Returns:
        Whatever ``InferenceEngine.generate`` returns for the prompts.
    """
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000,
                hidden_size=512,
                intermediate_size=1536,
                num_attention_heads=4,
                num_hidden_layers=16,
            )
        )
        .cuda()
        .half()
    )
    model = model.eval()

    # Token ids are drawn from [0, 100); each prompt length is drawn from
    # [1, max(1024 // batch_size, 32)]. The length is sampled first, then the
    # ids, once per prompt -- keep this order so the seeded streams are stable.
    prompts_token_ids = [
        np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
        for _ in range(batch_size)
    ]

    input_len = 1024
    output_len = 128
    do_sample = True
    top_p = 0.5
    top_k = 50

    # The CUDA-graph and eager configurations differ only in this one flag.
    inference_config = InferenceConfig(
        max_batch_size=batch_size,
        max_input_len=input_len,
        max_output_len=output_len,
        use_cuda_kernel=False,
        use_cuda_graph=bool(use_cuda_graph),
        block_size=16,
    )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len

    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
    return outputs


def check_output_consistency(batch_size):
    """Assert that CUDA-graph and eager generation produce identical outputs.

    Args:
        batch_size: Batch size passed to both ``check_inference_engine`` runs.

    Raises:
        AssertionError: If the two runs return a different number of outputs
            or any pair of corresponding outputs differs.
    """
    cuda_graph_output = list(check_inference_engine(use_cuda_graph=True, batch_size=batch_size))
    naive_model_output = list(check_inference_engine(use_cuda_graph=False, batch_size=batch_size))

    # zip() stops at the shorter input, so a missing output would otherwise
    # go unnoticed; check the lengths explicitly first.
    assert len(cuda_graph_output) == len(naive_model_output), (
        f"\nCUDA Graph Output count: {len(cuda_graph_output)}"
        f"\nOrigin Output count: {len(naive_model_output)}"
    )

    for s1, s2 in zip(cuda_graph_output, naive_model_output):
        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"


def run_dist(rank, world_size, port):
    """Worker passed to ``spawn``: launch colossalai, then run the checks.

    Args:
        rank: Forwarded to ``colossalai.launch``.
        world_size: Forwarded to ``colossalai.launch``.
        port: Forwarded to ``colossalai.launch``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    for batch_size in (32, 64, 128):
        check_output_consistency(batch_size)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
    """Run the CUDA-graph consistency check via ``spawn`` with one process."""
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_cuda_graph_infer()