import pytest import torch from transformers.models.llama import LlamaConfig import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.core.request_handler import RequestHandler, RunningList from colossalai.inference.struct import RequestStatus, Sequence from colossalai.testing import rerun_if_address_is_in_use, spawn def check_running_list(): """ Test the RunningList Structure. """ running_list = RunningList(prefill_ratio=1.2) seq1 = Sequence( request_id=1, prompt="abc", input_token_id=[1, 2, 3], block_size=16, eos_token_id=0, pad_token_id=0, sample_params=None, block_table=1, ) running_list.append(seq1) assert running_list.ready_for_prefill() assert running_list.decoding == [] and running_list.prefill[0] == seq1 seq = running_list.find_seq(seq1.request_id) assert seq == seq1 running_list.remove(seq1) assert running_list.is_empty() def check_request_handler(): """ Test main function of RequestHandler """ inference_config = InferenceConfig( max_input_len=10, max_output_len=10, block_size=8, ) model_config = LlamaConfig( hidden_size=32, num_hidden_layers=2, num_attention_heads=4, ) request_handler = RequestHandler(inference_config, model_config) seq1 = Sequence( request_id=1, prompt="abc", input_token_id=[1, 2, 3, 4, 5], block_size=16, eos_token_id=0, pad_token_id=0, sample_params=None, block_table=torch.tensor([-1, -1]), ) request_handler.add_sequence(seq1) # the priority should be 1 assert request_handler.waiting_list[1][0] == seq1 assert request_handler._has_waiting() request_handler.abort_sequence(seq1.request_id) assert not request_handler._has_waiting() seq1.status = RequestStatus.WAITING request_handler.add_sequence(seq1) request_handler.schedule() def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") check_running_list() check_request_handler() @pytest.mark.dist @rerun_if_address_is_in_use() def test_running_list_and_request_handler(): spawn(run_dist, 1) if __name__ == "__main__": test_running_list_and_request_handler()