"""Tests for speculative decoding in ColossalAI inference: Drafter token
speculation and KV-cache trimming, plus end-to-end generation with spec-dec
enabled on an InferenceEngine."""

import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

NUM_LAYERS = 2
MAX_LEN = 100


@pytest.mark.parametrize("spec_num", [5])
def test_drafter(spec_num: int):
    torch.manual_seed(123)

    device = get_current_device()

    # A tiny two-layer Llama keeps the drafter cheap enough for a unit test.
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()

    drafter = Drafter(drafter_model, tokenizer, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)
    past_kv_length = input_ids.size(1) + spec_num - 1
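    # Only the first spec_num - 1 speculated tokens are fed back through the
    # drafter (the final token never re-enters the model), so the KV cache is
    # expected to span input_len + spec_num - 1 positions, as asserted below.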

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert out.past_key_values[0][0].size(2) == past_kv_length

    # Simulate the main model rejecting spec_num - 1 tokens and check that
    # trim_kv_cache rolls the drafter's KV cache back by the same amount.
    reject_num = max(0, spec_num - 1)
    trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num)
    assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num


def check_sd():
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Dummy configs for testing: a tiny two-layer drafter alongside a larger main model.
    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()

    # Llama-7B-sized hidden dimensions but only 8 randomly initialized layers,
    # so the test exercises shapes and control flow rather than output quality.
    large_config = LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    large_config.pad_token_id = tokenizer.eos_token_id
    main_model = LlamaForCausalLM(large_config)

    inference_config = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, inference_config)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
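    # With spec-dec enabled, each decoding step should have the drafter propose
    # up to n_spec_tokens tokens for the main model to verify (standard
    # speculative decoding); the test only checks that generation completes.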

    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        max_length=MAX_LEN,
        eos_token_id=tokenizer.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.disable_spec_dec()
    engine.clear_spec_dec()
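
    # disable_spec_dec() switches the engine back to normal decoding, while
    # clear_spec_dec() also releases the drafter, as the assertions below check.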
    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    # One prompt in, one output sequence of MAX_LEN token ids out.
    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_sd()


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
    spawn(run_dist, nprocs=1)


if __name__ == "__main__":
    test_drafter(spec_num=5)
    test_spec_dec()