mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Fix & Update Inference Tests (compatibility w/ main)
parent 56ed09aba5
commit 8754abae24
@@ -270,7 +270,7 @@ def llama_rmsnorm_forward(
     return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
 
 
-class NopadLlamaMLP(ParallelModule, LlamaMLP):
+class NopadLlamaMLP(LlamaMLP, ParallelModule):
     def __init__(
         self,
         config: LlamaConfig,
@@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP):
         return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
 
 
-class NopadLlamaAttention(ParallelModule, LlamaAttention):
+class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def __init__(
         self,
         config: LlamaConfig,
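
The two class-signature changes above swap the inheritance order. In Python, base-class order determines the method resolution order (MRO), i.e. which parent wins when both define the same attribute or method; listing the HuggingFace class (LlamaMLP, LlamaAttention) first makes its implementation take precedence over ParallelModule on lookup. A minimal standalone sketch of the effect (illustrative only, not ColossalAI code):

    class Base1:
        def forward(self):
            return "Base1"

    class Base2:
        def forward(self):
            return "Base2"

    class Old(Base1, Base2):  # leftmost base wins: forward() -> "Base1"
        pass

    class New(Base2, Base1):  # order swapped: forward() -> "Base2"
        pass

    print(Old().forward())                    # Base1
    print(New().forward())                    # Base2
    print([c.__name__ for c in New.__mro__])  # ['New', 'Base2', 'Base1', 'object']
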
@@ -4,7 +4,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.kernel.triton import context_attention_unpadded
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
+from tests.test_infer.test_kernels.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
 
 try:
     import triton # noqa
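
This hunk and most of the remaining import changes in the commit are the same mechanical rename: the shared test helpers moved from tests.test_infer.test_ops to tests.test_infer.test_kernels, so every import of kernel_utils and of the sibling test modules is updated to the new path. A hypothetical fail-fast guard (not part of the commit) that a test file could use to surface a stale path with a clearer message:

    import importlib.util

    # assumed new location of the shared Triton test helpers
    HELPERS = "tests.test_infer.test_kernels.triton.kernel_utils"

    try:
        spec = importlib.util.find_spec(HELPERS)
    except ModuleNotFoundError:
        spec = None
    if spec is None:
        raise ImportError(f"expected test helpers at {HELPERS}; check for a renamed tests/ tree")
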
@@ -2,14 +2,14 @@ import torch
 
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
-from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
+from tests.test_infer.test_kernels.triton.test_decoding_attn import prepare_data
 
 try:
     import triton # noqa
@@ -3,7 +3,7 @@ import torch
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     generate_caches_and_block_tables_vllm,
@@ -2,7 +2,7 @@ import torch
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     mock_alloc_block_table_and_kvcache_v2,
     mock_alloc_block_table_and_kvcache_v3,
     mock_alloc_single_token,
@@ -4,8 +4,8 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
-from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
+from tests.test_infer.test_kernels.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
+from tests.test_infer.test_kernels.triton.test_kvcache_copy import prepare_data
 
 try:
     import triton # noqa
@@ -1,7 +1,7 @@
 import torch
 
 from colossalai.kernel.triton import get_xine_cache
-from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
+from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
 
 try:
     import triton # noqa
@@ -80,7 +80,7 @@ def check_config_and_inference():
 
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_config_and_inference()
 
 
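
This hunk, and the identical launch changes in the later test files, track an API change on main: colossalai.launch no longer takes the legacy config dict, so call sites pass only rank/world_size/port/host. A minimal compatibility shim, assuming only that the installed colossalai exposes launch with one of the two signatures (a sketch, not part of the commit):

    import inspect

    import colossalai

    def launch_compat(rank: int, world_size: int, port: int, host: str = "localhost") -> None:
        # pass the legacy empty config dict only if this version still expects it
        if "config" in inspect.signature(colossalai.launch).parameters:
            colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host=host)
        else:
            colossalai.launch(rank=rank, world_size=world_size, port=port, host=host)
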
@@ -80,7 +80,7 @@ def check_output_consistency(batch_size):
 
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_output_consistency(32)
     check_output_consistency(64)
     check_output_consistency(128)
@@ -157,7 +157,7 @@ def check_spec_dec(num_layers, max_length):
 
 
 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
 
     if ret:
         ret[rank] = func_to_run(**kwargs)
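
The run_dist helper above follows the repo's multi-process test pattern: each spawned worker initializes the distributed runtime, runs the target function, and, when a shared ret mapping is passed in, stores its result under its own rank so the parent process can inspect per-rank outputs. A simplified, self-contained sketch of that pattern using torch.multiprocessing directly (assumptions: a Manager dict standing in for ret, no real process group):

    import torch.multiprocessing as mp

    def worker(rank: int, world_size: int, ret) -> None:
        # stand-in for colossalai.launch(...) followed by func_to_run(**kwargs)
        ret[rank] = f"rank {rank} of {world_size} done"

    if __name__ == "__main__":
        world_size = 2
        with mp.Manager() as manager:
            ret = manager.dict()
            mp.spawn(worker, args=(world_size, ret), nprocs=world_size)
            print(dict(ret))  # {0: 'rank 0 of 2 done', 1: 'rank 1 of 2 done'}
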
@@ -7,11 +7,11 @@ import torch
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
 
 inference_ops = InferenceOpsLoader().load()
 
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v3,
@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
+from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
 
 inference_ops = InferenceOpsLoader().load()
 
@@ -4,7 +4,10 @@ import torch.nn.functional as F
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
+from tests.test_infer.test_kernels.triton.kernel_utils import (
+    generate_caches_and_block_tables_v3,
+    mock_alloc_single_token,
+)
 
 inference_ops = InferenceOpsLoader().load()
 
@@ -7,8 +7,8 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
 
 inference_ops = InferenceOpsLoader().load()
 
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
-from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
+from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
+from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
 
 
 def numpy_allclose(x, y, rtol, atol):
@@ -5,7 +5,7 @@ from packaging import version
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.triton import context_attention_unpadded
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     torch_attn_ref,
@@ -6,14 +6,14 @@ from packaging import version
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
 
 try:
     import triton # noqa
@@ -4,7 +4,7 @@ from packaging import version
 
 from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     mock_alloc_single_token,
@@ -4,7 +4,7 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     mock_alloc_block_table_and_kvcache_v2,
     mock_alloc_block_table_and_kvcache_v3,
 )
@@ -164,7 +164,7 @@ def check_cache_manager(test_config):
 
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_cache_manager()
 
 
@@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
 from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
-# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
 BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
 
 
@@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):
 
 
 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
 
     if ret:
         ret[rank] = func_to_run(**kwargs)
@@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
 @parameterize("prompt_template", [None, "baichuan"])
 @parameterize("do_sample", [False])
 @parameterize("use_cuda_kernel", [True])
-def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
+def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
     kwargs1 = {
         "use_engine": True,
         "prompt_template": prompt_template,
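
Renaming test_tp_engine to check_tp_engine (here and at its call site in the next hunk) follows pytest's collection convention: only test_-prefixed functions are collected as tests, so the @parameterize-decorated helper is no longer picked up on its own and runs only through test_inference_engine. A stripped-down illustration of the convention (hypothetical demo names, not the actual test code):

    def check_tp_engine_demo(prompt_template) -> None:
        # helper: ignored by pytest's collector because of the check_ prefix
        assert prompt_template in (None, "baichuan")

    def test_inference_engine_demo() -> None:
        # collected by pytest; drives the helper over its parameter values
        for template in (None, "baichuan"):
            check_tp_engine_demo(template)
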
@@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
-    test_tp_engine()
+    check_tp_engine()
 
 
 if __name__ == "__main__":
@@ -90,7 +90,7 @@ def check_request_handler():
 
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_running_list()
     check_request_handler()
 