# Mirror of https://github.com/hpcaitech/ColossalAI
import random |
|
|
|
import numpy as np |
|
import torch |
|
from torch.multiprocessing import Manager |
|
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM |
|
|
|
import colossalai |
|
from colossalai.inference.config import InferenceConfig |
|
from colossalai.inference.core.engine import InferenceEngine |
|
from colossalai.testing import rerun_if_address_is_in_use, spawn |
|
|
|
|
|
def data_gen(batch_size: int = 4, seq_len: int = 512, device=None) -> torch.Tensor:
    """Generate a random batch of input token ids for the test model.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of each sequence.
        device: device for the tensor. Defaults to the current CUDA device
            when not given, preserving the original behavior; pass "cpu"
            to run without a GPU.

    Returns:
        An int64 tensor of shape ``(batch_size, seq_len)`` with token ids
        drawn uniformly from ``[10, 30000)``.
    """
    if device is None:
        # Original behavior: place the batch on the current CUDA device.
        device = torch.cuda.current_device()
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=device)
    return input_ids
|
|
|
|
|
def setup_seed(seed):
    """Seed every RNG the test depends on for reproducibility.

    Covers the torch default CPU generator, all CUDA device generators,
    numpy's legacy global RNG, and the stdlib ``random`` module.
    """
    # NOTE: torch.manual_seed and torch.random.manual_seed both seed the
    # default generator; both calls are kept to mirror the original sequence.
    for seed_fn in (
        torch.manual_seed,
        torch.random.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seed_fn(seed)
|
|
|
|
|
def check_streamingllm():
    """Exercise StreamingLLM KV-cache eviction on a tiny random Llama model.

    Builds a 1-prompt request, then steps the engine in bursts, asserting
    the running batch's block table and cached sequence length after each
    burst. The expected values encode the eviction pattern produced by
    ``start_token_size=4`` and ``generated_token_size=32``.
    """
    setup_seed(20)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tiny_config = LlamaConfig(
        vocab_size=50000,
        hidden_size=512,
        intermediate_size=1536,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_hidden_layers=16,
    )
    model = LlamaForCausalLM(tiny_config).cuda().eval()

    # Single prompt of 4 tokens (matches start_token_size below).
    prompt_ids = data_gen(1, 4)
    output_len = 128

    inference_config = InferenceConfig(
        max_batch_size=1,
        max_output_len=output_len,
        dtype="fp32",
        use_cuda_kernel=True,
        enable_streamingllm=True,
        start_token_size=4,
        generated_token_size=32,
    )

    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert engine.generation_config.max_new_tokens == output_len

    engine.add_request(prompts_token_ids=prompt_ids)
    assert engine.request_handler._has_waiting()
    assert inference_config.start_token_size == inference_config.block_size

    running_bb = engine.request_handler.running_bb

    # Each entry: (steps to run, expected block table, expected cached seq len).
    # Values taken verbatim from the original assertion sequence.
    schedule = [
        (12, [0, -1, -1, -1], 16),
        (16, [0, 1, -1, -1], 32),
        (16, [0, 1, 2, -1], 48),
        (16, [0, 2, 3, -1], 48),
        (1, [0, 2, 3, 1], 49),
        (15, [0, 3, 1, -1], 48),
    ]
    for num_steps, expected_table, expected_len in schedule:
        for _ in range(num_steps):
            engine.step()
        assert running_bb.block_tables[0].tolist() == expected_table
        assert running_bb.seq_lengths[0].item() == expected_len
|
|
|
|
|
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    """Per-process entry point: set up the distributed env, then run a function.

    Args:
        rank: rank of this spawned process.
        world_size: total number of processes.
        port: master port for the distributed rendezvous.
        func_to_run: callable executed after launch.
        ret: optional shared container; when truthy, this rank's result is
            stored at ``ret[rank]`` so the parent process can read it back.
        **kwargs: forwarded to ``func_to_run``.
    """
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    result = func_to_run(**kwargs)
    if ret:
        ret[rank] = result
|
|
|
|
|
@rerun_if_address_is_in_use()
def test_engine():
    """Spawn one worker running the StreamingLLM check and return its result."""
    manager = Manager()
    # One shared slot per rank (world_size == 1), pre-filled with a sentinel.
    shared_results = manager.list([-1])

    spawn(run_dist, 1, func_to_run=check_streamingllm, ret=shared_results)
    return shared_results[0]
|
|
|
|
|
if __name__ == "__main__":
    # Allow running the distributed StreamingLLM check directly, outside pytest.
    test_engine()
|
|
|