import random

import numpy as np
import torch
from torch.multiprocessing import Manager
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn


def data_gen(batch_size: int = 4, seq_len: int = 512):
    # Generate random input token ids on the current CUDA device.
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=torch.cuda.current_device())
    return input_ids


def setup_seed(seed):
    # Seed every RNG involved so the generation trace is reproducible.
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_streamingllm():
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Small randomly initialized Llama model; only the KV-cache bookkeeping is under test here.
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000,
            hidden_size=512,
            intermediate_size=1536,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=16,
        )
    ).cuda()
    model = model.eval()

    # A single request with a 4-token prompt.
    input_token_ids = data_gen(1, 4)

    output_len = 128

    inference_config = InferenceConfig(
        max_batch_size=1,
        max_output_len=output_len,
        dtype="fp32",
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=32,
    )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    inference_engine.add_request(prompts_token_ids=input_token_ids)
    assert inference_engine.request_handler._has_waiting()
    # With StreamingLLM enabled, the sink region is aligned to one full cache block.
    assert inference_config.start_token_size == inference_config.block_size

    request_handler = inference_engine.request_handler
    running_bb = request_handler.running_bb

    # The assertions below verify the rolling KV cache: one 16-token sink block plus
    # generated_token_size (32) tokens, i.e. at most 48 cached tokens across three blocks.

    # Prompt (4 tokens) + 12 generated tokens fill the first block.
    for _ in range(12):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 16

    # 16 more tokens fill the second block.
    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1]
    assert running_bb.seq_lengths[0].item() == 32

    # 16 more tokens fill the third block, reaching the 48-token cache budget.
    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1]
    assert running_bb.seq_lengths[0].item() == 48

    # Past the budget, the oldest non-sink block (block 1) is evicted and block 3 is
    # allocated, so the cached length is held at 48 while the sink block 0 is kept.
    for _ in range(16):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 2, 3, -1]
    assert running_bb.seq_lengths[0].item() == 48

    # The next token reuses the freed block (block 1), briefly raising the cached length to 49.
    for _ in range(1):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 2, 3, 1]
    assert running_bb.seq_lengths[0].item() == 49

    # Once that block fills, block 2 is evicted in turn and the cached length falls back to 48.
    for _ in range(15):
        inference_engine.step()

    assert running_bb.block_tables[0].tolist() == [0, 3, 1, -1]
    assert running_bb.seq_lengths[0].item() == 48


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if ret:
        ret[rank] = func_to_run(**kwargs)
    else:
        func_to_run(**kwargs)


@rerun_if_address_is_in_use()
def test_engine():
    manager = Manager()
    result_list = manager.list([-1] * 1)  # Create a shared list

    spawn(run_dist, 1, func_to_run=check_streamingllm, ret=result_list)
    return result_list[0]


if __name__ == "__main__":
    test_engine()