[Fix] Fix & Update Inference Tests (compatibility w/ main)

pull/5685/head
Yuanheng Zhao 2024-05-05 16:28:56 +00:00
parent 56ed09aba5
commit 8754abae24
30 changed files with 32 additions and 30 deletions

View File

@ -270,7 +270,7 @@ def llama_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
class NopadLlamaMLP(ParallelModule, LlamaMLP):
class NopadLlamaMLP(LlamaMLP, ParallelModule):
def __init__(
self,
config: LlamaConfig,
@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP):
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
class NopadLlamaAttention(ParallelModule, LlamaAttention):
class NopadLlamaAttention(LlamaAttention, ParallelModule):
def __init__(
self,
config: LlamaConfig,

View File

@ -4,7 +4,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
from tests.test_infer.test_kernels.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
try:
import triton # noqa

View File

@ -2,14 +2,14 @@ import torch
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
from tests.test_infer.test_kernels.triton.test_decoding_attn import prepare_data
try:
import triton # noqa

View File

@ -3,7 +3,7 @@ import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,

View File

@ -2,7 +2,7 @@ import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
mock_alloc_single_token,

View File

@ -4,8 +4,8 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
from tests.test_infer.test_kernels.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
from tests.test_infer.test_kernels.triton.test_kvcache_copy import prepare_data
try:
import triton # noqa

View File

@ -1,7 +1,7 @@
import torch
from colossalai.kernel.triton import get_xine_cache
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
try:
import triton # noqa

View File

@ -80,7 +80,7 @@ def check_config_and_inference():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_config_and_inference()

View File

@ -80,7 +80,7 @@ def check_output_consistency(batch_size):
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency(32)
check_output_consistency(64)
check_output_consistency(128)

View File

@ -157,7 +157,7 @@ def check_spec_dec(num_layers, max_length):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)

View File

@ -7,11 +7,11 @@ import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v3,

View File

@ -3,7 +3,7 @@ import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
inference_ops = InferenceOpsLoader().load()

View File

@ -4,7 +4,10 @@ import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
inference_ops = InferenceOpsLoader().load()

View File

@ -7,8 +7,8 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):

View File

@ -5,7 +5,7 @@ from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,

View File

@ -6,14 +6,14 @@ from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
try:
import triton # noqa

View File

@ -4,7 +4,7 @@ from packaging import version
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
mock_alloc_single_token,

View File

@ -4,7 +4,7 @@ from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.triton import decoding_fused_rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import (
from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
)

View File

@ -164,7 +164,7 @@ def check_cache_manager(test_config):
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_cache_manager()

View File

@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
test_tp_engine()
check_tp_engine()
if __name__ == "__main__":

View File

@ -90,7 +90,7 @@ def check_request_handler():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_running_list()
check_request_handler()