mirror of https://github.com/hpcaitech/ColossalAI
103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
from itertools import accumulate
|
|
|
|
import pytest
|
|
import torch
|
|
from packaging import version
|
|
from transformers import BloomConfig, BloomForCausalLM
|
|
from transformers.tokenization_utils_base import BatchEncoding
|
|
|
|
import colossalai
|
|
from colossalai.inference.tensor_parallel import TPInferEngine
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.shardformer import ShardConfig
|
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
|
|
|
# Tensor-parallel degree under test; also the number of processes spawned by `test_engine`.
TP_SIZE = 2
# Limits passed to TPInferEngine when it is constructed in `run`.
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

# `torch.version.cuda` is None on PyTorch builds without CUDA (e.g. CPU-only wheels).
# Guard against that so importing this module does not raise during test collection;
# the test is then simply skipped via the `skipif` marker on `test_engine`.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.5")
|
|
|
|
|
|
@parameterize(
    "test_config",
    [
        {
            "tp_size": TP_SIZE,
        }
    ],
)
def run(test_config):
    """Exercise TPInferEngine on a small, randomly initialised Bloom model.

    The test proceeds in three steps:
      1. build the engine and check its basic attributes,
      2. check that `prepare_batch_state` gives consistent results for a
         `BatchEncoding` and for a plain list of token-id lists,
      3. run `generate` on a random batch as a smoke test.

    Args:
        test_config: dict supplied by `parameterize`; only the "tp_size" key is read.
    """
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config)
    model = model.half()
    model.to(torch.cuda.current_device())

    # 1. check TPInferEngine init and model optimization
    shard_config = ShardConfig(enable_tensor_parallelism=test_config["tp_size"] > 1, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    assert infer_engine.cache_manager is not None
    assert infer_engine.tp_size == TP_SIZE
    assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

    # 2. check data preparation
    input_ids_list = [
        [80540, 15473, 3331, 11970, 90472, 361, 61335],
        [80540, 15473, 3331, 11970],
        [80540, 15473, 3331, 11970],
        [80540, 15473],
    ]
    batch_size = len(input_ids_list)
    max_seq_len = max(len(li) for li in input_ids_list)
    # Build a mask of width `max_seq_len` per sequence: leading zeros, then one `1` per token.
    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
    for i, li in enumerate(input_ids_list):
        attention_mask[i][max_seq_len - len(li) :] = [1] * len(li)
    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
    inputs_batch_encoding = BatchEncoding(data=data)
    # `seq_lengths` / `start_loc` are the reference values for the assertions that are
    # currently disabled below; they are kept so those checks can be re-enabled as-is.
    seq_lengths = [len(li) for li in input_ids_list]
    start_loc = list(accumulate([0] + seq_lengths[:-1]))
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
    start_loc = torch.tensor(start_loc, dtype=torch.int32)
    # BatchEncoding as inputs
    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
    # input token id list as inputs
    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
    assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)

    # The following tests are discarded for now, and will be reused after all features are added
    # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
    # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
    # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

    # 3. check optimized model generate
    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    generate_kwargs = dict(do_sample=False)
    infer_engine.generate(input_ids, **generate_kwargs)

    torch.cuda.empty_cache()
|
|
|
|
|
|
def check_engine(rank, world_size, port):
    """Per-process entry point handed to `spawn` by `test_engine`.

    Disables the existing loggers, launches colossalai on localhost with the
    NCCL backend, and then executes the parameterized `run` test body.

    Args:
        rank: rank forwarded to `colossalai.launch`.
        world_size: world size forwarded to `colossalai.launch`.
        port: port forwarded to `colossalai.launch`.
    """
    disable_existing_loggers()
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    colossalai.launch(**launch_kwargs)
    # `run` takes no explicit argument here: `parameterize` supplies `test_config`.
    run()
|
|
|
|
|
|
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine():
    """Pytest entry point: spawn `TP_SIZE` processes that each run `check_engine`.

    Skipped when `CUDA_SUPPORT` is false.
    """
    num_procs = TP_SIZE
    spawn(check_engine, num_procs)
|
|
|
|
|
|
# Allow running this test module directly as a script, without going through pytest.
if __name__ == "__main__":
    test_engine()
|