Speculative Decoding

Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.

Speculative decoding uses two models: a small drafter model and a large main model. The drafter generates a few candidate tokens sequentially, and the main model then verifies all of those candidates in parallel and accepts the validated ones. Decoding speeds up because the drafter can propose several tokens with lower latency than the main model would need to generate them one at a time.
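The draft-then-verify loop described above can be sketched in plain Python. This is a minimal sketch under simplifying assumptions: greedy decoding only, and hypothetical toy functions (`main_model`, `drafter_model`) that map a token prefix to the next token in place of real networks. It is not Colossal-Inference's implementation; `n_spec` is our analogue of the `max_n_spec_tokens` setting used in the Usage section.

```python
from typing import Callable, List

# Toy stand-in for a model: given a token prefix, return its greedy next token.
NextToken = Callable[[List[int]], int]


def speculative_step(prefix: List[int], drafter: NextToken, main: NextToken, n_spec: int) -> List[int]:
    """One greedy draft-then-verify round; returns the tokens accepted in this round."""
    # 1. The drafter proposes n_spec tokens sequentially (cheap per token).
    draft: List[int] = []
    for _ in range(n_spec):
        draft.append(drafter(prefix + draft))

    # 2. The main model checks every draft position. In speculative decoding this is one
    #    parallel forward pass; it is called once per position here only for readability.
    accepted: List[int] = []
    for i, token in enumerate(draft):
        expected = main(prefix + draft[:i])
        if expected != token:
            # First mismatch: keep the main model's own token and end the round.
            accepted.append(expected)
            return accepted
        accepted.append(token)

    # 3. Every draft token was accepted, so the same pass also yields one bonus token.
    accepted.append(main(prefix + draft))
    return accepted


def generate(prompt: List[int], drafter: NextToken, main: NextToken, n_spec: int, max_new_tokens: int) -> List[int]:
    """Repeat speculative rounds until at least max_new_tokens new tokens exist."""
    out = list(prompt)
    while len(out) - len(prompt) < max_new_tokens:
        out += speculative_step(out, drafter, main, n_spec)
    return out[len(prompt):len(prompt) + max_new_tokens]


# Hypothetical toy models: the main model counts modulo 10; the drafter agrees
# except after token 3, where it wrongly guesses 0.
def main_model(seq: List[int]) -> int:
    return (seq[-1] + 1) % 10


def drafter_model(seq: List[int]) -> int:
    return 0 if seq[-1] == 3 else (seq[-1] + 1) % 10


print(speculative_step([0], drafter_model, main_model, n_spec=5))  # [1, 2, 3, 4]
print(generate([0], drafter_model, main_model, n_spec=5, max_new_tokens=12))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2]
```

Because the main model's token is kept at the first mismatch, `generate` returns exactly what plain greedy decoding with the main model would; the gain comes from emitting several tokens per main-model verification pass.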

Colossal-Inference also supports GLIDE, a modified drafter architecture that reuses the key and value caches of the main model. This improves the acceptance rate and increases the speed-up ratio. Details can be found in the paper "GLIDE with a CAPE: A Low-Hassle Method to Accelerate Speculative Decoding" on arXiv.
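To see why a higher acceptance rate raises the speed-up, the standard analysis from the original speculative decoding paper (Leviathan et al., 2023) can be written as two small helpers. It assumes each draft token is accepted independently with probability `alpha` (a simplification; real acceptance depends on context), `n_spec` drafted tokens per round, and a drafter whose per-token cost is a fraction `cost_ratio` of one main-model step. The function names are illustrative, and the inputs below are hypothetical, not Colossal-Inference measurements.

```python
def expected_tokens_per_round(alpha: float, n_spec: int) -> float:
    """Expected tokens emitted per main-model verification pass.

    Assumes i.i.d. acceptance with probability alpha and counts the one token the
    main model always contributes (a correction at the first mismatch or a bonus token).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if alpha == 1.0:
        return float(n_spec + 1)  # limit of the geometric sum as alpha -> 1
    return (1.0 - alpha ** (n_spec + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, n_spec: int, cost_ratio: float) -> float:
    """Expected wall-time improvement over plain decoding.

    One round costs n_spec drafter steps (each cost_ratio of a main-model step)
    plus one main-model pass, i.e. n_spec * cost_ratio + 1 main-model step units.
    """
    return expected_tokens_per_round(alpha, n_spec) / (n_spec * cost_ratio + 1.0)


# Hypothetical inputs: 5 speculative tokens, drafter costing 5% of a main-model step.
for alpha in (0.5, 0.8):
    tokens = expected_tokens_per_round(alpha, n_spec=5)
    speedup = expected_speedup(alpha, n_spec=5, cost_ratio=0.05)
    print(f"alpha={alpha}: {tokens:.3f} tokens/round, {speedup:.3f}x")
# alpha=0.5: 1.969 tokens/round, 1.575x
# alpha=0.8: 3.689 tokens/round, 2.951x
```

Raising `alpha` grows the geometric sum, which is the mechanism by which a better-aligned drafter such as GLIDE can increase the speed-up. Measured throughput also includes overheads this model ignores.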

Currently, Colossal-Inference offers a GLIDE drafter model compatible with Vicuna-7B. The fine-tuned model, cxdu/glide47m-vicuna7b, is available on the Hugging Face Hub: https://huggingface.co/cxdu/glide47m-vicuna7b.
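As we understand the GLIDE paper, reusing the main model's caches means the drafter attends, via cross-attention, over keys and values the main model has already computed for the context, so it does not recompute them. The toy sketch below shows that single step as single-head scaled dot-product attention in pure Python. The function name and tiny vectors are hypothetical and this is not the `GlideLlamaForCausalLM` code. In the Usage example, `GlideLlamaConfig` receives the main model's `large_hidden_size=4096` and `large_num_attention_heads=32`; we assume this lets the drafter match the main model's cache layout, which would imply a per-head dimension of 4096 / 32 = 128.

```python
import math
from typing import List

Vector = List[float]


def attend_to_main_cache(query: Vector, cached_keys: List[Vector], cached_values: List[Vector]) -> Vector:
    """Single-head scaled dot-product attention of one drafter query over reused main-model K/V.

    cached_keys / cached_values hold one vector per context position, as a main model's
    KV cache would; nothing here recomputes them.
    """
    head_dim = len(query)
    # Similarity between the drafter's query and each cached key, scaled by sqrt(head_dim).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(head_dim) for key in cached_keys]
    # Numerically stable softmax over context positions.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the cached value vectors.
    return [sum(w * value[j] for w, value in zip(weights, cached_values)) for j in range(len(cached_values[0]))]


# Toy 2-dimensional head over a 2-token context (hypothetical numbers).
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
# The query matches the first key, so the first value receives most of the weight.
print(attend_to_main_cache([1.0, 0.0], keys, values))
```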

Usage

For the main model, you can use the lmsys/vicuna-7b-v1.5 model card on the Hugging Face Hub. For a regular drafter model, you can use JackFram/llama-68m, and for the GLIDE drafter model, cxdu/glide47m-vicuna7b.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig

# launch colossalai, setup distributed environment
colossalai.launch_from_torch()

# main model
model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
model = AutoModelForCausalLM.from_pretrained(model_path_or_name)

# use the same tokenizer for both the main model and the drafter model
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token

# drafter model
drafter_model_path_or_name = "REPLACE_TO_LLAMA_68M_PATH_OR_MODEL_CARD"
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)

# Initialize the inference engine
inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=256,
    max_output_len=256,
    prefill_ratio=1.2,
    block_size=16,
    max_n_spec_tokens=5,
    prompt_template="vicuna",
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# turn on speculative decoding with the drafter model
engine.enable_spec_dec(drafter_model)

prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions. "
generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_length=128,
    num_beams=1,
    do_sample=False,
)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)

# use GLIDE Llama model as drafter model
drafter_model_path_or_name = "cxdu/glide47m-vicuna7b"
glide_config = GlideLlamaConfig(
    intermediate_size=8192,
    large_hidden_size=4096,
    large_num_attention_heads=32,
    num_hidden_layers=1,
)
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name, config=glide_config)

# turn on speculative decoding with the GLIDE model
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)
```

You can run the script above with:

```shell
colossalai run --nproc_per_node 1 script_name.py
```

Benchmark

Throughput with batch size 1, measured on the GSM8K and MT-Bench datasets on an NVIDIA H800 80GB GPU:

| Method | Tokens/Sec |
| --- | --- |
| Non-Spec-Dec | ~90 |
| Spec-Dec | ~115 |
| Spec-Dec with GLIDE Model | ~135 |
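The relative speed-ups implied by the benchmark numbers can be computed directly, and a simple timing helper shows one way to get a tokens-per-second figure for your own runs. This is a sketch: `measure_tokens_per_sec`, `relative_speedups`, and the `generate_fn` callable are hypothetical helpers, not Colossal-Inference APIs. The Usage example prints generated text rather than token counts, so with the real engine you would need to count the new tokens yourself. Because the benchmark values are approximate, so are the ratios.

```python
import time
from typing import Callable, Dict, Sequence, Tuple


def measure_tokens_per_sec(generate_fn: Callable[[], Sequence[int]]) -> Tuple[int, float]:
    """Time one call to generate_fn, which must return the newly generated token ids."""
    start = time.perf_counter()
    new_tokens = generate_fn()
    # Guard against a zero interval on extremely fast (e.g. fake) calls.
    elapsed = max(time.perf_counter() - start, 1e-9)
    return len(new_tokens), len(new_tokens) / elapsed


def relative_speedups(tokens_per_sec: Dict[str, float], baseline: str) -> Dict[str, float]:
    """Throughput of each method divided by the baseline's throughput."""
    base = tokens_per_sec[baseline]
    return {name: tps / base for name, tps in tokens_per_sec.items()}


# Approximate values from the benchmark table above.
table = {"Non-Spec-Dec": 90.0, "Spec-Dec": 115.0, "Spec-Dec with GLIDE Model": 135.0}
for name, ratio in relative_speedups(table, "Non-Spec-Dec").items():
    print(f"{name}: {ratio:.2f}x")
# Non-Spec-Dec: 1.00x
# Spec-Dec: 1.28x
# Spec-Dec with GLIDE Model: 1.50x
```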