"""Smoke test for the inference config defaults and ``Sequence`` bookkeeping."""

import pytest

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_config_and_inference():
    """Assert the default config value and basic ``Sequence`` state handling.

    Builds an ``InferenceConfig`` with no arguments and a ``Sequence`` with a
    three-token input, then checks:

    * ``max_batch_size`` of the argument-less config equals 8;
    * ``mark_running()`` leaves the status at ``RequestStatus.RUNNING``;
    * ``recycle()`` leaves the status at ``RequestStatus.RECYCLED``;
    * the length properties report 3 / 3 / 0 (sentence / input / output);
    * ``check_finish()`` is falsy for a sequence with no output yet.
    """
    config = InferenceConfig()
    assert config.max_batch_size == 8

    sequence = Sequence(
        request_id=1,
        prompt="abc",
        input_token_id=[1, 2, 3],
        block_size=16,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=256,
    )

    # Status transitions.
    sequence.mark_running()
    assert sequence.status == RequestStatus.RUNNING
    sequence.recycle()
    assert sequence.status == RequestStatus.RECYCLED

    # Length bookkeeping: three input tokens were supplied, nothing generated.
    assert sequence.sentence_len == 3
    assert sequence.input_len == 3
    assert sequence.output_len == 0

    # Truthiness check instead of comparing to the ``False`` literal (PEP 8).
    assert not sequence.check_finish()


def run_dist(rank, world_size, port):
    """Launch ColossalAI on ``localhost`` and run the checks.

    Args:
        rank: Forwarded to ``colossalai.launch`` as ``rank``.
        world_size: Forwarded to ``colossalai.launch`` as ``world_size``.
        port: Forwarded to ``colossalai.launch`` as ``port``.
    """
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_config_and_inference()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_config_and_inference():
    """Pytest entry point: run ``run_dist`` through ``spawn``."""
    # NOTE(review): the second argument is presumably the number of processes
    # to spawn -- confirm against ``colossalai.testing.spawn``.
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_config_and_inference()