diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 5bdadca78..27ab7c76a 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -91,7 +91,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
       options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 75
     defaults:
       run:
         shell: bash
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 732adf56a..abecd4886 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -81,7 +81,7 @@ import colossalai
 from colossalai.inference import InferenceEngine, InferenceConfig
 from pprint import pprint
 
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
 
 # Step 1: create a model in "transformers" way
 model_path = "lmsys/vicuna-7b-v1.3"
diff --git a/colossalai/inference/spec/README.md b/colossalai/inference/spec/README.md
index 96ae1622d..d6faaea2e 100644
--- a/colossalai/inference/spec/README.md
+++ b/colossalai/inference/spec/README.md
@@ -23,7 +23,7 @@ from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
 from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig
 
 # launch colossalai, setup distributed environment
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
 
 # main model
 model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index fade655e1..148b2bf88 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -1,11 +1,7 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List
 
-import torch
-from ordered_set import OrderedSet
-
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ class Sequence:
         )
 
 
-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> None:
-        tesnor_list = []
-        block_table = None
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tesnor_list.append(seq.block_table)
-
-        block_table = torch.stack(tesnor_list)
-        return block_table
-
-    def clear_batch(self) -> None:
-        """
-        Clear sequence set and block table if we need to abort this batch.
-        Prefill: clear sequence set and move them to running batch(external)
-        Decoding: mark unfinished sequences as aborted.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def fliter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sentences from a batch.
-
-        Returns:
-            List["Sequence"]: List of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove sequence from the batch.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequence to batch
-
-        Args:
-            seqs (List["Sequence"]): The list of new sequences.
-        """
-        # covnert single sequence to list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> Sequence:
-        """
-        Delete sequence in batch
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> None:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sentence in the batch.
-
-        Args:
-            tokens (List[int]): A batch of tokens
-        """
-
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get batch_size of this batch
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get bacth inputs for forward inference computation.
-        """
-
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all the padding_id in seq are the same at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
-        """
-        Flattening the input tokens.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self):
-        """
-        Get the input_len of each sentence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all the padding_id in seq are the same at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(
-            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
-        )
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
-
-
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return [pad] * (max_len - len(x)) + x
-
-
-def _make_tensor_with_pad(
-    x: Union[List[List[int]], List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
-):
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
diff --git a/examples/inference/benchmark_ops/test_ci.sh b/examples/inference/benchmark_ops/test_ci.sh
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/llama/benchmark_llama.py
similarity index 100%
rename from examples/inference/benchmark_llama.py
rename to examples/inference/llama/benchmark_llama.py
diff --git a/examples/inference/benchmark_llama3.py b/examples/inference/llama/benchmark_llama3.py
similarity index 98%
rename from examples/inference/benchmark_llama3.py
rename to examples/inference/llama/benchmark_llama3.py
index 2829090f0..07ebdb2b1 100644
--- a/examples/inference/benchmark_llama3.py
+++ b/examples/inference/llama/benchmark_llama3.py
@@ -182,7 +182,7 @@ def benchmark_inference(args):
 
 
 def inference(rank, world_size, port, args):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     benchmark_inference(args)
 
 
diff --git a/examples/inference/llama_generation.py b/examples/inference/llama/llama_generation.py
similarity index 96%
rename from examples/inference/llama_generation.py
rename to examples/inference/llama/llama_generation.py
index 83ed7a6bc..5a373dccd 100644
--- a/examples/inference/llama_generation.py
+++ b/examples/inference/llama/llama_generation.py
@@ -17,7 +17,7 @@ def infer(args):
     # ==============================
     # Launch colossalai, setup distributed environment
     # ==============================
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()
 
     # ==============================
@@ -59,7 +59,7 @@ def infer(args):
     coordinator.print_on_master(out[0])
 
 
-# colossalai run --nproc_per_node 1 llama_gen.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/llama/run_benchmark.sh
similarity index 100%
rename from examples/inference/run_benchmark.sh
rename to examples/inference/llama/run_benchmark.sh
diff --git a/examples/inference/llama/test_ci.sh b/examples/inference/llama/test_ci.sh
new file mode 100644
index 000000000..b130fc486
--- /dev/null
+++ b/examples/inference/llama/test_ci.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+echo "Skip the test (this test is slow)"
+
+# bash ./run_benchmark.sh
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 709e82baa..fdd8442f5 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -35,7 +35,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
-from colossalai.kernel.extensions.pybind.flash_attention import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt
deleted file mode 100644
index b05cafc67..000000000
--- a/requirements/requirements-infer.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-ordered_set
-transformers==4.36.2
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index bb97a2a3a..58c7f780f 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -1,6 +1,4 @@
 diffusers
-fbgemm-gpu==0.2.0
-ordered_set
 pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
diff --git a/tests/test_infer/__init__.py b/tests/test_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py
index cc0389af9..d6f542129 100755
--- a/tests/test_infer/test_config_and_struct.py
+++ b/tests/test_infer/test_config_and_struct.py
@@ -2,7 +2,7 @@ import pytest
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
+from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
@@ -20,27 +20,6 @@ def check_config_and_inference():
         max_output_len=256,
     )
 
-    sequence2 = Sequence(
-        request_id=2,
-        prompt="bcd",
-        input_token_id=[4, 5, 6],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
-
-    sequence3 = Sequence(
-        request_id=3,
-        prompt="efg",
-        input_token_id=[7, 8, 9],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
     sequence.mark_running()
     assert sequence.status == RequestStatus.RUNNING
     sequence.recycle()
@@ -51,33 +30,6 @@
     assert sequence.output_len == 0
     assert sequence.check_finish() == False
 
-    batch = BatchInfo(
-        max_batch_size=8,
-        kv_max_split_num=16,
-        num_heads=2,
-        head_dim=128,
-    )
-    batch.add_seqs([sequence])
-    batch.add_seqs([sequence2, sequence3])
-
-    # add duplicated sequence to test that it will not be counted twice
-    batch.add_seqs([sequence])
-
-    assert batch.is_empty == False
-    assert batch.get_batch_size() == 3
-    batch.update_batch_tokens([1, 2, 3])
-    seq = batch.abort_seq(sequence)
-    seq2 = batch.fliter_batch()[0]
-
-    assert batch.get_batch_size() == 1
-    assert seq.output_len == 1
-    assert seq.output_token_id == [1]
-    assert seq2.output_len == 1
-    assert seq2.output_token_id == [2]
-
-    batch.clear_batch()
-    assert batch.is_empty == True
-
 
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py
index 4cdc62fbe..2be188571 100644
--- a/tests/test_infer/test_cuda_graph.py
+++ b/tests/test_infer/test_cuda_graph.py
@@ -86,7 +86,7 @@ def run_dist(rank, world_size, port):
     check_output_consistency(128)
 
 
-@pytest.mark.dist
+@pytest.mark.largedist
 @rerun_if_address_is_in_use()
 def test_cuda_graph_infer():
     spawn(run_dist, 1)
diff --git a/tests/test_infer/test_drafter.py b/tests/test_infer/test_drafter.py
index 686229f38..3c5dda157 100644
--- a/tests/test_infer/test_drafter.py
+++ b/tests/test_infer/test_drafter.py
@@ -11,13 +11,16 @@ MAX_LEN = 100
 SPEC_NUM = 5
 
 
+@pytest.fixture(scope="module")
+def tokenizer():
+    return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+
 @pytest.mark.parametrize("spec_num", [SPEC_NUM])
-def test_drafter(spec_num: int):
+def test_drafter(tokenizer, spec_num: int):
     torch.manual_seed(123)
     device = get_current_device()
-
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
     toy_config.pad_token_id = tokenizer.eos_token_id
     drafter_model = LlamaForCausalLM(toy_config)
@@ -39,10 +42,9 @@
     assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
 
 
-def test_spec_dec():
+def test_spec_dec(tokenizer):
     spec_num = SPEC_NUM
     device = get_current_device()
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     tokenizer.pad_token = tokenizer.eos_token
 
     # Dummy config for Glide Model
@@ -67,5 +69,6 @@
 
 
 if __name__ == "__main__":
-    test_drafter(spec_num=SPEC_NUM)
-    test_spec_dec()
+    dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
+    test_spec_dec(dummy_tokenizer)
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index a0ddbbc7b..8061c50d2 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -165,8 +165,10 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
         func_to_run(**kwargs)
 
 
+@pytest.mark.largedist
 @parameterize("prompt_template", [None, "llama"])
 @parameterize("do_sample", [False])
+@rerun_if_address_is_in_use()
 def test_tp_engine(prompt_template, do_sample):
     kwargs1 = {
         "use_engine": True,
@@ -186,18 +188,14 @@ def test_tp_engine(prompt_template, do_sample):
     assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
 
 
+@pytest.mark.largedist
 @parameterize("num_layers", [1])
 @parameterize("max_length", [64])
+@rerun_if_address_is_in_use()
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
 
 
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_inference_engine():
+if __name__ == "__main__":
     test_tp_engine()
     test_spec_dec()
-
-
-if __name__ == "__main__":
-    test_inference_engine()
diff --git a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
index e34fada97..9d76858ed 100644
--- a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
@@ -86,11 +86,11 @@ def torch_attn_unpad(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
-@pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("bsz", [7, 32])
+@pytest.mark.parametrize("block_size", [16, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
 @pytest.mark.parametrize("num_attn_heads", [16])
-@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
+@pytest.mark.parametrize("kv_group_num", [1, 4])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
 @pytest.mark.parametrize("use_new_kcache_layout", [True, False])
diff --git a/tests/test_infer/test_kernels/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
index 24741fecf..e487129c1 100644
--- a/tests/test_infer/test_kernels/triton/test_decoding_attn.py
+++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
@@ -68,11 +68,11 @@ def prepare_data(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
-@pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("bsz", [7, 16])
+@pytest.mark.parametrize("block_size", [16, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
 @pytest.mark.parametrize("num_attn_heads", [16])
-@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
+@pytest.mark.parametrize("kv_group_num", [1, 4])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
@@ -187,7 +187,7 @@ def test_flash_decoding(
 
     rtol = 1e-4
     # After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
-    if bsz == 32 and use_alibi_slopes:
+    if bsz >= 16 and use_alibi_slopes:
         rtol = 100
 
     numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
diff --git a/tests/test_infer/test_kernels/triton/test_kvcache_copy.py b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py
index 336eb256b..4aa34ae30 100644
--- a/tests/test_infer/test_kernels/triton/test_kvcache_copy.py
+++ b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py
@@ -70,9 +70,9 @@ def prepare_data(
 
 
 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
-@pytest.mark.parametrize("bsz", [4, 7, 32])
+@pytest.mark.parametrize("bsz", [7, 32])
 @pytest.mark.parametrize("block_size", [16, 32, 64])
-@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [16])
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("n_tokens", [1, 5])
diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py
index 1091370ce..79ed6675d 100644
--- a/tests/test_infer/test_models/test_attention.py
+++ b/tests/test_infer/test_models/test_attention.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -7,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_copy_to_cache():
     key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
@@ -24,6 +26,7 @@ def test_copy_to_cache():
     assert cache[3, 0, 0, 0] == 1
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_convert_kvcache():
     cache = torch.ones(8, 3, 8, 3)
     key = torch.ones(2, 1, 3, 3) + 1
@@ -34,6 +37,7 @@ def test_convert_kvcache():
     assert converted_cache.shape == (2, 10, 3, 3)
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_context_attention():
     """
    test config: head_num = 4, head_size = 4
@@ -86,6 +90,7 @@ def test_context_attention():
     assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skip(reason="This test is not used in the current version.")
 def test_decoding_attention():
     # test the pipeline of decoding attention
     attn = PagedAttention()
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
index 3d6fc3bdb..736fab5ff 100644
--- a/tests/test_infer/test_models/test_baichuan.py
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -128,7 +128,7 @@ def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
     not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
     reason="There is no local model address included, please replace this address with a valid one.",
 )
-@pytest.mark.dist
+@pytest.mark.largedist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
     check_tp_engine()