[inference] moved ops tests to test_infer (#5354)

pull/5356/head
Frank Lee 2024-02-02 13:51:22 +08:00 committed by GitHub
parent db1a763307
commit e76acbb076
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 7 additions and 4 deletions

View File

@ -63,6 +63,9 @@ def check_config_and_inference():
batch.add_seqs([sequence])
batch.add_seqs([sequence2, sequence3])
# add duplicated sequence to test that it will not be counted twice
batch.add_seqs([sequence])
assert batch.is_empty == False
assert batch.get_batch_size() == 3
batch.update_batch_tokens([1, 2, 3])

View File

@ -6,7 +6,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
try:
import triton # noqa

View File

@ -4,7 +4,7 @@ from packaging import version
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import (
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
generate_caches_and_block_tables_v2,
prepare_padding_mask,

View File

@ -5,7 +5,7 @@ from packaging import version
from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
try:
import triton # noqa

View File

@ -4,7 +4,7 @@ from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.triton import rotary_embedding
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
try:
import triton # noqa