ColossalAI/examples/inference/llama/llama_generation.py

106 lines
4.6 KiB
Python

import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy
# For Llama 3, we'll use the following configuration
MODEL_CLS = AutoModelForCausalLM
POLICY_CLS = NoPaddingLlamaModelInferPolicy
def infer(args):
# ==============================
# Launch colossalai, setup distributed environment
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Load model and tokenizer
# ==============================
model_path_or_name = args.model
model = MODEL_CLS.from_pretrained(model_path_or_name)
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
tokenizer.pad_token = tokenizer.eos_token
# coordinator.print_on_master(f"Model Config:\n{model.config}")
# ==============================
# Initialize InferenceEngine
# ==============================
inference_config = InferenceConfig(
dtype=args.dtype,
max_batch_size=args.max_batch_size,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
prefill_ratio=1.2,
block_size=16,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
)
coordinator.print_on_master(f"Initializing Inference Engine...")
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
# ==============================
# Generation
# ==============================
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_length=args.max_length,
do_sample=args.do_sample,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
coordinator.print_on_master(out)
# ==============================
# Optionally, load drafter model and proceed speculative decoding
# ==============================
drafter_model_path_or_name = args.drafter_model
if drafter_model_path_or_name is not None:
drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
# turn on speculative decoding with the drafter model
engine.enable_spec_dec(drafter_model)
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
coordinator.print_on_master(out)
engine.disable_spec_dec()
# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
parser.add_argument("--drafter_model", type=str, help="Path to the drafter model or model name")
parser.add_argument(
"-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
)
parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
parser.add_argument("-i", "--max_input_len", type=int, default=128, help="Max input length")
parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Max output length")
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
# Generation configs
parser.add_argument("--max_length", type=int, default=64, help="Max length for generation")
parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
args = parser.parse_args()
infer(args)