mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements * modify example inference struct * add test ci scripts * mark test_infer as submodule * rm deprecated cls & deps * import of HAS_FLASH_ATTN * prune inference tests to be run * prune triton kernel tests * increment pytest timeout mins * revert import path in openmoepull/5697/head
parent
f9afe0addd
commit
55cc7f3df7
|
@ -91,7 +91,7 @@ jobs:
|
|||
container:
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 75
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
|
@ -81,7 +81,7 @@ import colossalai
|
|||
from colossalai.inference import InferenceEngine, InferenceConfig
|
||||
from pprint import pprint
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
# Step 1: create a model in "transformers" way
|
||||
model_path = "lmsys/vicuna-7b-v1.3"
|
||||
|
|
|
@ -23,7 +23,7 @@ from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
|
|||
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig
|
||||
|
||||
# launch colossalai, setup distributed environment
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
|
||||
# main model
|
||||
model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Tuple, Union
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from ordered_set import OrderedSet
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@ -170,242 +166,6 @@ class Sequence:
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchInfo:
|
||||
"""
|
||||
Information to be passed and used for a batch of sequences.
|
||||
"""
|
||||
|
||||
max_batch_size: int
|
||||
kv_max_split_num: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
sequences_set: OrderedSet[Sequence] = None
|
||||
is_prompts: bool = True
|
||||
device: torch.device = None
|
||||
dtype: torch.dtype = None
|
||||
fd_inter_tensor: FDIntermTensors = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device is None:
|
||||
self.device = torch.cuda.current_device()
|
||||
if self.sequences_set is None:
|
||||
self.sequences_set = OrderedSet()
|
||||
if self.fd_inter_tensor is None:
|
||||
self.fd_inter_tensor = FDIntermTensors()
|
||||
|
||||
def init_fd_tensors(self):
|
||||
if not self.fd_inter_tensor.is_initialized:
|
||||
self.fd_inter_tensor.initialize(
|
||||
max_batch_size=self.max_batch_size,
|
||||
num_attn_heads=self.num_heads,
|
||||
kv_max_split_num=self.kv_max_split_num,
|
||||
head_dim=self.head_dim,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_block_table_tensor(self) -> None:
|
||||
tesnor_list = []
|
||||
block_table = None
|
||||
|
||||
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
|
||||
|
||||
for seq in self.sequences_set:
|
||||
block_table = seq.block_table
|
||||
assert (
|
||||
block_table is not None
|
||||
), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
|
||||
tesnor_list.append(seq.block_table)
|
||||
|
||||
block_table = torch.stack(tesnor_list)
|
||||
return block_table
|
||||
|
||||
def clear_batch(self) -> None:
|
||||
"""
|
||||
Clear sequence set and block table if we need to abort this batch.
|
||||
Prefill: clear sequence set and move them to running batch(external)
|
||||
Decoding: mark unfinished sequences as aborted.
|
||||
"""
|
||||
if self.is_prompts:
|
||||
self.sequences_set.clear()
|
||||
else:
|
||||
for seq in self.sequences_set:
|
||||
seq.mark_aborted()
|
||||
if seq.check_finish():
|
||||
seq.mark_finished()
|
||||
|
||||
self.sequences_set.clear()
|
||||
|
||||
def fliter_batch(self) -> List["Sequence"]:
|
||||
"""
|
||||
Remove completed sentences from a batch.
|
||||
|
||||
Returns:
|
||||
List["Sequence"]: List of finished sequences.
|
||||
"""
|
||||
finish_seqs = []
|
||||
for seq in self.sequences_set:
|
||||
if seq.check_finish():
|
||||
finish_seqs.append(seq)
|
||||
for finish_seq in finish_seqs:
|
||||
self.sequences_set.discard(finish_seq)
|
||||
return finish_seqs
|
||||
|
||||
def abort_seq(self, seq: "Sequence") -> "Sequence":
|
||||
"""
|
||||
Remove sequence from the batch.
|
||||
"""
|
||||
if not seq.check_finish():
|
||||
seq.status = RequestStatus.ABORTED
|
||||
self.sequences_set.discard(seq)
|
||||
return seq
|
||||
|
||||
def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
|
||||
"""
|
||||
Add new sequence to batch
|
||||
|
||||
Args:
|
||||
seqs (List["Sequence"]): The list of new sequences.
|
||||
"""
|
||||
# covnert single sequence to list
|
||||
if isinstance(seqs, Sequence):
|
||||
seqs = [seqs]
|
||||
|
||||
for seq in seqs:
|
||||
if seq in self.sequences_set:
|
||||
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
|
||||
continue
|
||||
self.sequences_set.add(seq)
|
||||
|
||||
def del_seq(self, seq: Sequence) -> Sequence:
|
||||
"""
|
||||
Delete sequence in batch
|
||||
"""
|
||||
self.sequences_set.discard(seq)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> None:
|
||||
"""
|
||||
Check whether sequences_set is empty.
|
||||
"""
|
||||
return not self.sequences_set
|
||||
|
||||
def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
|
||||
"""
|
||||
Add an output token for each sentence in the batch.
|
||||
|
||||
Args:
|
||||
tokens (List[int]): A batch of tokens
|
||||
"""
|
||||
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens = tokens.tolist()
|
||||
|
||||
assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
|
||||
|
||||
for seq, token in zip(self.sequences_set, tokens):
|
||||
if not isinstance(token, list):
|
||||
if not isinstance(token, int):
|
||||
raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
|
||||
token = [token]
|
||||
seq.output_token_id += token
|
||||
seq.check_finish()
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Get batch_size of this batch
|
||||
"""
|
||||
return len(self.sequences_set)
|
||||
|
||||
def get_batch_inputs(self) -> torch.LongTensor:
|
||||
"""
|
||||
Get bacth inputs for forward inference computation.
|
||||
"""
|
||||
|
||||
input_list = []
|
||||
|
||||
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
|
||||
|
||||
for seq in self.sequences_set:
|
||||
if self.is_prompts:
|
||||
if seq.output_len > 0:
|
||||
input_list.append(seq.input_token_id + seq.output_token_id)
|
||||
else:
|
||||
input_list.append(seq.input_token_id)
|
||||
else:
|
||||
input_list.append([seq.output_token_id[-1]])
|
||||
|
||||
max_seq_len = max(len(sub_list) for sub_list in input_list)
|
||||
|
||||
# We assume that all the padding_id in seq are the same at present.
|
||||
return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
|
||||
|
||||
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""
|
||||
Flattening the input tokens.
|
||||
"""
|
||||
input_list = []
|
||||
|
||||
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
|
||||
|
||||
for seq in self.sequences_set:
|
||||
if self.is_prompts:
|
||||
input_list.extend(seq.input_token_id)
|
||||
else:
|
||||
input_list.append(seq.output_token_id[-1])
|
||||
|
||||
return torch.tensor(input_list, dtype=torch.long, device=self.device)
|
||||
|
||||
def get_sequence_lengths(self):
|
||||
"""
|
||||
Get the input_len of each sentence in this batch.
|
||||
"""
|
||||
len_list = []
|
||||
|
||||
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
|
||||
|
||||
for seq in self.sequences_set:
|
||||
len_list.append(seq.sentence_len)
|
||||
|
||||
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
||||
|
||||
def get_attn_mask(self) -> torch.Tensor:
|
||||
"""
|
||||
Generate and return attention mask.
|
||||
"""
|
||||
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
|
||||
|
||||
past_values = []
|
||||
# We assume that all the padding_id in seq are the same at present.
|
||||
padding_id = self.sequences_set[0].pad_token_id
|
||||
|
||||
for seq in self.sequences_set:
|
||||
past_values.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
max_seq_len = max(len(sub_list) for sub_list in past_values)
|
||||
attn_mask = _make_tensor_with_pad(
|
||||
past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
|
||||
)
|
||||
|
||||
return attn_mask.ne(padding_id).long()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
||||
|
||||
|
||||
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
||||
assert len(x) <= max_len
|
||||
return [pad] * (max_len - len(x)) + x
|
||||
|
||||
|
||||
def _make_tensor_with_pad(
|
||||
x: Union[List[List[int]], List[int]],
|
||||
max_len: int,
|
||||
pad: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
pin_memory: bool = False,
|
||||
):
|
||||
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
|
||||
return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
|
||||
|
|
|
@ -182,7 +182,7 @@ def benchmark_inference(args):
|
|||
|
||||
|
||||
def inference(rank, world_size, port, args):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
benchmark_inference(args)
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ def infer(args):
|
|||
# ==============================
|
||||
# Launch colossalai, setup distributed environment
|
||||
# ==============================
|
||||
colossalai.launch_from_torch(config={})
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
|
@ -59,7 +59,7 @@ def infer(args):
|
|||
coordinator.print_on_master(out[0])
|
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 llama_gen.py -m MODEL_PATH
|
||||
# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash
|
||||
echo "Skip the test (this test is slow)"
|
||||
|
||||
# bash ./run_benchmark.sh
|
|
@ -35,7 +35,7 @@ from transformers.utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from colossalai.kernel.extensions.pybind.flash_attention import HAS_FLASH_ATTN
|
||||
from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe.layers import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
ordered_set
|
||||
transformers==4.36.2
|
|
@ -1,6 +1,4 @@
|
|||
diffusers
|
||||
fbgemm-gpu==0.2.0
|
||||
ordered_set
|
||||
pytest
|
||||
coverage==7.2.3
|
||||
git+https://github.com/hpcaitech/pytest-testmon
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
|
||||
from colossalai.inference.struct import RequestStatus, Sequence
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -20,27 +20,6 @@ def check_config_and_inference():
|
|||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence2 = Sequence(
|
||||
request_id=2,
|
||||
prompt="bcd",
|
||||
input_token_id=[4, 5, 6],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
eos_token_id=2,
|
||||
pad_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence3 = Sequence(
|
||||
request_id=3,
|
||||
prompt="efg",
|
||||
input_token_id=[7, 8, 9],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
eos_token_id=2,
|
||||
pad_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
sequence.mark_running()
|
||||
assert sequence.status == RequestStatus.RUNNING
|
||||
sequence.recycle()
|
||||
|
@ -51,33 +30,6 @@ def check_config_and_inference():
|
|||
assert sequence.output_len == 0
|
||||
assert sequence.check_finish() == False
|
||||
|
||||
batch = BatchInfo(
|
||||
max_batch_size=8,
|
||||
kv_max_split_num=16,
|
||||
num_heads=2,
|
||||
head_dim=128,
|
||||
)
|
||||
batch.add_seqs([sequence])
|
||||
batch.add_seqs([sequence2, sequence3])
|
||||
|
||||
# add duplicated sequence to test that it will not be counted twice
|
||||
batch.add_seqs([sequence])
|
||||
|
||||
assert batch.is_empty == False
|
||||
assert batch.get_batch_size() == 3
|
||||
batch.update_batch_tokens([1, 2, 3])
|
||||
seq = batch.abort_seq(sequence)
|
||||
seq2 = batch.fliter_batch()[0]
|
||||
|
||||
assert batch.get_batch_size() == 1
|
||||
assert seq.output_len == 1
|
||||
assert seq.output_token_id == [1]
|
||||
assert seq2.output_len == 1
|
||||
assert seq2.output_token_id == [2]
|
||||
|
||||
batch.clear_batch()
|
||||
assert batch.is_empty == True
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
|
|
@ -86,7 +86,7 @@ def run_dist(rank, world_size, port):
|
|||
check_output_consistency(128)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_cuda_graph_infer():
|
||||
spawn(run_dist, 1)
|
||||
|
|
|
@ -11,13 +11,16 @@ MAX_LEN = 100
|
|||
SPEC_NUM = 5
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
|
||||
def test_drafter(spec_num: int):
|
||||
def test_drafter(tokenizer, spec_num: int):
|
||||
torch.manual_seed(123)
|
||||
|
||||
device = get_current_device()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
||||
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||
drafter_model = LlamaForCausalLM(toy_config)
|
||||
|
@ -39,10 +42,9 @@ def test_drafter(spec_num: int):
|
|||
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
|
||||
|
||||
|
||||
def test_spec_dec():
|
||||
def test_spec_dec(tokenizer):
|
||||
spec_num = SPEC_NUM
|
||||
device = get_current_device()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Dummy config for Glide Model
|
||||
|
@ -67,5 +69,6 @@ def test_spec_dec():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_drafter(spec_num=SPEC_NUM)
|
||||
test_spec_dec()
|
||||
dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
|
||||
test_spec_dec(dummy_tokenizer)
|
||||
|
|
|
@ -165,8 +165,10 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
|
|||
func_to_run(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@parameterize("prompt_template", [None, "llama"])
|
||||
@parameterize("do_sample", [False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_tp_engine(prompt_template, do_sample):
|
||||
kwargs1 = {
|
||||
"use_engine": True,
|
||||
|
@ -186,18 +188,14 @@ def test_tp_engine(prompt_template, do_sample):
|
|||
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@parameterize("num_layers", [1])
|
||||
@parameterize("max_length", [64])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_spec_dec(num_layers, max_length):
|
||||
spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_inference_engine():
|
||||
if __name__ == "__main__":
|
||||
test_tp_engine()
|
||||
test_spec_dec()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inference_engine()
|
||||
|
|
|
@ -86,11 +86,11 @@ def torch_attn_unpad(
|
|||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("bsz", [7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
|
||||
@pytest.mark.parametrize("num_attn_heads", [16])
|
||||
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
|
||||
@pytest.mark.parametrize("kv_group_num", [1, 4])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
|
||||
|
|
|
@ -68,11 +68,11 @@ def prepare_data(
|
|||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("bsz", [7, 16])
|
||||
@pytest.mark.parametrize("block_size", [16, 32])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16])
|
||||
@pytest.mark.parametrize("num_attn_heads", [16])
|
||||
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
|
||||
@pytest.mark.parametrize("kv_group_num", [1, 4])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
@pytest.mark.parametrize("q_len", [1, 5])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
|
@ -187,7 +187,7 @@ def test_flash_decoding(
|
|||
|
||||
rtol = 1e-4
|
||||
# After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
|
||||
if bsz == 32 and use_alibi_slopes:
|
||||
if bsz >= 16 and use_alibi_slopes:
|
||||
rtol = 100
|
||||
|
||||
numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
|
||||
|
|
|
@ -70,9 +70,9 @@ def prepare_data(
|
|||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("bsz", [7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [16])
|
||||
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
@pytest.mark.parametrize("n_tokens", [1, 5])
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
import torch
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
|
@ -7,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
|
|||
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_copy_to_cache():
|
||||
key = torch.ones((2, 11, 3, 3))
|
||||
key[0, 9, :, :] = 0
|
||||
|
@ -24,6 +26,7 @@ def test_copy_to_cache():
|
|||
assert cache[3, 0, 0, 0] == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_convert_kvcache():
|
||||
cache = torch.ones(8, 3, 8, 3)
|
||||
key = torch.ones(2, 1, 3, 3) + 1
|
||||
|
@ -34,6 +37,7 @@ def test_convert_kvcache():
|
|||
assert converted_cache.shape == (2, 10, 3, 3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_context_attention():
|
||||
"""
|
||||
test config: head_num = 4, head_size = 4
|
||||
|
@ -86,6 +90,7 @@ def test_context_attention():
|
|||
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is not used in the current version.")
|
||||
def test_decoding_attention():
|
||||
# test the pipeline of decoding attention
|
||||
attn = PagedAttention()
|
||||
|
|
|
@ -128,7 +128,7 @@ def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
|
|||
not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
|
||||
reason="There is no local model address included, please replace this address with a valid one.",
|
||||
)
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_inference_engine():
|
||||
check_tp_engine()
|
||||
|
|
Loading…
Reference in New Issue