
[example] Update Inference Example (#5725)

* [example] update inference example
Yuanheng Zhao, 6 months ago, committed by GitHub · commit 8bcfe360fd
3 changed files:

  1. colossalai/inference/spec/README.md (96)
  2. examples/inference/llama/README.md (47)
  3. examples/inference/llama/llama_generation.py (32)

colossalai/inference/spec/README.md (96 changes)

@@ -1,96 +0,0 @@
# Speculative Decoding
Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.
Both a drafter model (small model) and a main model (large model) are used during the speculative decoding process. The drafter model generates a few tokens sequentially, and the main model then validates those candidate tokens in parallel and accepts the validated ones. This speeds up decoding, since speculating multiple tokens with the drafter model has lower latency than generating them with the main model.
Moreover, Colossal-Inference also supports GLIDE, a modified draft-model architecture that reuses the key and value caches of the main model, which improves the acceptance rate and increases the speed-up ratio. Details can be found in the research paper GLIDE with a CAPE - A Low-Hassle Method to Accelerate Speculative Decoding on [arXiv](https://arxiv.org/pdf/2402.02082.pdf).
Right now, Colossal-Inference offers a GLIDE model compatible with Vicuna 7B. You can find the fine-tuned GLIDE drafter model `cxdu/glide47m-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide47m-vicuna7b.
## Usage
For the main model, you might want to use the model card `lmsys/vicuna-7b-v1.5` at [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5).
For the regular drafter model, you might want to use the model card `JackFram/llama-68m` at [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m).
For the GLIDE drafter model, you could use the model card `cxdu/glide47m-vicuna7b` at [HuggingFace Hub](https://huggingface.co/cxdu/glide47m-vicuna7b).
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig

# launch colossalai, setup distributed environment
colossalai.launch_from_torch()

# main model
model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"
model = AutoModelForCausalLM.from_pretrained(model_path_or_name)

# use the same tokenizer for both the main model and the drafter model
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token

# drafter model
drafter_model_path_or_name = "REPLACE_TO_LLAMA_68M_PATH_OR_MODEL_CARD"
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)

# Initialize the inference engine
inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=256,
    max_output_len=256,
    prefill_ratio=1.2,
    block_size=16,
    max_n_spec_tokens=5,
    prompt_template="vicuna",
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# turn on speculative decoding with the drafter model
engine.enable_spec_dec(drafter_model)

prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions. "
generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_length=128,
    num_beams=1,
    do_sample=False,
)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)

# use GLIDE Llama model as drafter model
drafter_model_path_or_name = "cxdu/glide47m-vicuna7b"
glide_config = GlideLlamaConfig(
    intermediate_size=8192,
    large_hidden_size=4096,
    large_num_attention_heads=32,
    num_hidden_layers=1,
)
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name, config=glide_config)

# turn on speculative decoding with the GLIDE model
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
out = engine.generate(prompts=[prompt], generation_config=generation_config)
print(out)
```
You could run the above code by
```bash
colossalai run --nproc_per_node 1 script_name.py
```
## Benchmark
With batch size 1, tested on the GSM8K and MT-Bench datasets on an NVIDIA H800 80G GPU:
| Method | Tokens/Sec |
| :--------------------------- | :--------- |
| Non-Spec-Dec | ~90 |
| Spec-Dec | ~115 |
| Spec-Dec with GLIDE Model | ~135 |
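
If you want to reproduce a rough tokens-per-second number yourself, the sketch below times only the `engine.generate` call and counts output tokens with the tokenizer. It is an approximation, not the exact benchmark setup, and it assumes the `engine`, `tokenizer`, `prompt`, and `generation_config` objects from the example above.
```python
import time

# Rough throughput measurement (approximation): reuses `engine`, `tokenizer`,
# `prompt`, and `generation_config` from the example above.
start = time.perf_counter()
out = engine.generate(prompts=[prompt], generation_config=generation_config)
elapsed = time.perf_counter() - start

# Count tokens in the returned text; depending on the engine's output format,
# this may include prompt tokens, so treat the result as an estimate.
num_tokens = len(tokenizer.encode(out[0]))
print(f"~{num_tokens / elapsed:.1f} tokens/sec")
```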

examples/inference/llama/README.md (47 changes)

@@ -0,0 +1,47 @@
## Run Inference
The provided script `llama_generation.py` shows how to configure an inference engine, initialize it, and run inference on a given model. It uses `AutoModelForCausalLM` as the model class and `NoPaddingLlamaModelInferPolicy` as the policy class, so it is ready to run inference with Llama 3.
For a basic setting, you could run the example by:
```bash
colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --max_length 128
```
Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --max_length 128 --tp_size 2
```
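Under the hood, the script roughly follows the flow sketched below. This is a simplified illustration based on the engine API; the `InferenceConfig` values and the prompt are illustrative and not necessarily the script's actual defaults.
```python
import colossalai
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import GenerationConfig, InferenceEngine

# launch colossalai and set up the distributed environment
colossalai.launch_from_torch()

# load the main model and tokenizer (corresponds to the -m/--model argument)
model_path_or_name = "PATH_MODEL"
model = AutoModelForCausalLM.from_pretrained(model_path_or_name)
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token

# configure and build the inference engine
inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=1,
    max_input_len=256,
    max_output_len=256,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# generate from a prompt (mirrors the --prompt / --max_length / --do_sample flags)
generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_length=128,
    do_sample=False,
)
out = engine.generate(prompts=["Introduce some landmarks in the United Kingdom, such as"], generation_config=generation_config)
print(out)
```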
## Run Speculative Decoding
Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.
Both a drafter model (small model) and a main model (large model) are used during the speculative decoding process. The drafter model generates a few tokens sequentially, and the main model then validates those candidate tokens in parallel and accepts the validated ones. This speeds up decoding, since speculating multiple tokens with the drafter model has lower latency than generating them with the main model.
Moreover, Colossal-Inference also supports GLIDE, a modified draft-model architecture that reuses the key and value caches of the main model, which improves the acceptance rate and increases the speed-up ratio. Details can be found in the research paper GLIDE with a CAPE - A Low-Hassle Method to Accelerate Speculative Decoding on [arXiv](https://arxiv.org/pdf/2402.02082.pdf).
Right now, Colossal-Inference offers a GLIDE model compatible with Vicuna 7B (https://huggingface.co/lmsys/vicuna-7b-v1.5). You can find the fine-tuned GLIDE drafter model `cxdu/glide-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide-vicuna7b.
Benchmarking on the GSM8K and MT-Bench datasets with batch size 1 on an H800 GPU, speculative decoding gives a speed-up of around 1.28x, and speculative decoding with the GLIDE model as the drafter gives a speed-up of around 1.5x.
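To make the draft-and-verify flow concrete, here is a conceptual sketch of a single speculative decoding round. It is only an illustration of the idea, not the engine's actual implementation; `greedy_next_token` and `next_tokens_parallel` are hypothetical helpers standing in for a forward pass plus argmax.
```python
def speculative_decoding_step(main_model, drafter_model, tokens, n_spec_tokens):
    """One conceptual round of speculative decoding (illustration only)."""
    # 1. The drafter proposes n_spec_tokens candidates sequentially (cheap, small model).
    candidates = []
    draft_input = list(tokens)
    for _ in range(n_spec_tokens):
        next_token = drafter_model.greedy_next_token(draft_input)  # hypothetical helper
        candidates.append(next_token)
        draft_input.append(next_token)

    # 2. The main model scores all candidate positions in one parallel forward pass.
    targets = main_model.next_tokens_parallel(tokens, candidates)  # hypothetical helper

    # 3. Accept the longest matching prefix of candidates; at the first mismatch,
    #    keep the main model's own token instead, so output quality is preserved.
    accepted = []
    for candidate, target in zip(candidates, targets):
        if candidate == target:
            accepted.append(candidate)
        else:
            accepted.append(target)
            break
    return tokens + accepted
```
Because several accepted tokens come out of a single main-model forward pass, per-token latency drops whenever the drafter's guesses are frequently accepted.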
## Usage
For the main model, you might want to use the model card `lmsys/vicuna-7b-v1.5` at [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5).
For the regular drafter model, you might want to use the model card `JackFram/llama-68m` at [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m).
For the GLIDE drafter model, you could use the model card `cxdu/glide-vicuna7b` at [HuggingFace Hub](https://huggingface.co/cxdu/glide-vicuna7b).
You could run speculative decoding by
```bash
colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128
```
Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128 --tp_size 2
```
If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as the drafter model and enable the feature with:
```python
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```
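For reference, here is a minimal sketch of loading the GLIDE drafter in Python; the `GlideLlamaConfig` values follow the Vicuna-7B GLIDE setting, and `engine` is assumed to be an already-initialized `InferenceEngine`.
```python
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM

# GLIDE drafter configuration matching the Vicuna-7B main model
glide_config = GlideLlamaConfig(
    intermediate_size=8192,
    large_hidden_size=4096,
    large_num_attention_heads=32,
    num_hidden_layers=1,
)
drafter_model = GlideLlamaForCausalLM.from_pretrained("cxdu/glide-vicuna7b", config=glide_config)

# turn on speculative decoding with the GLIDE drafter on an existing engine
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```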

examples/inference/llama/llama_generation.py (32 changes)

@@ -27,7 +27,7 @@ def infer(args):
     model = MODEL_CLS.from_pretrained(model_path_or_name)
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
-    coordinator.print_on_master(f"Model Config:\n{model.config}")
+    # coordinator.print_on_master(f"Model Config:\n{model.config}")

     # ==============================
     # Initialize InferenceEngine
@@ -52,20 +52,39 @@ def infer(args):
         pad_token_id=tokenizer.eos_token_id,
         eos_token_id=tokenizer.eos_token_id,
         max_length=args.max_length,
-        do_sample=True,
+        do_sample=args.do_sample,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
-    coordinator.print_on_master(out[0])
+    coordinator.print_on_master(out)
+
+    # ==============================
+    # Optionally, load drafter model and proceed speculative decoding
+    # ==============================
+    drafter_model_path_or_name = args.drafter_model
+    if drafter_model_path_or_name is not None:
+        drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
+        # turn on speculative decoding with the drafter model
+        engine.enable_spec_dec(drafter_model)
+        coordinator.print_on_master(f"Generating...")
+        out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
+        coordinator.print_on_master(out)
+        engine.disable_spec_dec()

 # colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
 # colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
     # ==============================
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+    parser.add_argument("--drafter_model", type=str, help="Path to the drafter model or model name")
     parser.add_argument(
         "-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
     )
@@ -75,7 +94,12 @@ if __name__ == "__main__":
     parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
     parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
     parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
-    parser.add_argument("--max_length", type=int, default=32, help="Max length for generation")
+    # Generation configs
+    parser.add_argument("--max_length", type=int, default=64, help="Max length for generation")
+    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
+    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
+    parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
+    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
     args = parser.parse_args()
     infer(args)
