You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_infer/test_drafter.py

109 lines
3.5 KiB

import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
# Number of hidden layers given to the toy LlamaConfig that backs the drafter model.
NUM_LAYERS = 2
# Passed as GenerationConfig.max_length; check_sd also expects exactly this many output token ids.
MAX_LEN = 100
@pytest.mark.parametrize("spec_num", [5])
def test_drafter(spec_num: int):
    """Check the output shapes of ``Drafter.speculate`` and of KV-cache trimming.

    A randomly initialised Llama model with ``NUM_LAYERS`` hidden layers acts as
    the drafter. It speculates ``spec_num`` tokens from a random 6-token prompt,
    after which the test verifies the reported speculated length, the shapes of
    the returned tokens and logits, and the sequence dimension of the KV cache
    before and after ``Drafter.trim_kv_cache``.

    Args:
        spec_num: Number of tokens the drafter is asked to speculate.
    """
    torch.manual_seed(123)

    device = get_current_device()
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    # Put the model on the same device that is handed to the Drafter and used
    # for the input ids, instead of hard-coding CUDA next to `device`.
    drafter_model = LlamaForCausalLM(toy_config).eval().to(device)
    drafter = Drafter(drafter_model, tokenizer, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)

    # Expected cache length is prompt length plus spec_num - 1; presumably the
    # last speculated token has not been fed back through the model yet --
    # confirm against Drafter.speculate.
    past_kv_length = input_ids.size(1) + spec_num - 1

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert out.past_key_values[0][0].size(2) == past_kv_length

    # Reject all but one speculated token and check that the cache shrinks by
    # exactly that many positions along dim 2.
    reject_num = max(0, spec_num - 1)
    trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
    assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def check_sd():
    """Run one generation through ``InferenceEngine`` with a drafter attached.

    Builds a small drafter model and a wider 8-layer main model (both randomly
    initialised), enables the engine's spec-dec mode with 5 speculative tokens,
    generates from a random 10-token prompt, then turns spec-dec off and clears
    it. Asserts that the engine no longer holds a drafter and that exactly one
    sequence of ``MAX_LEN`` token ids was returned.
    """
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    eos_id = tokenizer.eos_token_id

    # Drafter: dummy config, random weights, moved to the GPU in eval mode.
    small_cfg = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    small_cfg.pad_token_id = eos_id
    drafter_model = LlamaForCausalLM(small_cfg).eval().cuda()

    # Main model: dummy config as well, but with explicit (larger) dimensions.
    big_cfg_kwargs = dict(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    big_cfg = LlamaConfig(**big_cfg_kwargs)
    big_cfg.pad_token_id = eos_id
    main_model = LlamaForCausalLM(big_cfg)

    engine_cfg_kwargs = dict(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, InferenceConfig(**engine_cfg_kwargs))
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    prompt_ids = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    gen_cfg = GenerationConfig(
        pad_token_id=eos_id,
        max_length=MAX_LEN,
        eos_token_id=eos_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=prompt_ids,
        generation_config=gen_cfg,
        return_token_ids=True,
    )

    # Tear spec-dec down before inspecting the engine state.
    engine.disable_spec_dec()
    engine.clear_spec_dec()

    assert not engine.use_spec_dec
    assert engine.drafter is None
    assert engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1
    assert len(out_token_ids[0]) == MAX_LEN
def run_dist(rank, world_size, port):
    """Launch ColossalAI for this process on localhost, then run ``check_sd``.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch``.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    colossalai.launch(**launch_kwargs)
    check_sd()
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
    """Run ``run_dist`` in a single spawned process."""
    num_procs = 1
    spawn(run_dist, nprocs=num_procs)
if __name__ == "__main__":
    # Direct execution: run both tests in order, supplying test_drafter the
    # same spec_num value that its parametrization uses.
    test_drafter(spec_num=5)
    test_spec_dec()