import argparse

from torch import bfloat16, float16, float32
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy

# For Llama 3, we'll use the following configuration
MODEL_CLS = AutoModelForCausalLM
POLICY_CLS = NoPaddingLlamaModelInferPolicy
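
# Map the CLI --dtype choices to the torch dtypes used when loading model weights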
TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}

def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
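    # launch_from_torch picks up rank and world size from the environment variables
    # set by `colossalai run` / torchrun (see the example commands near the bottom of this file)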
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load model and tokenizer
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
    tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
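    # Llama tokenizers ship without a pad token; reuse the EOS token so inputs can be padded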
    tokenizer.pad_token = tokenizer.eos_token
    # coordinator.print_on_master(f"Model Config:\n{model.config}")

    # ==============================
    # Initialize InferenceEngine
    # ==============================
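    # block_size sets the per-block token granularity of the engine's paged KV cache;
    # prefill_ratio is a scheduling knob that (as far as we understand it) biases the
    # engine toward prefill over decode when batching requests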
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        max_input_len=args.max_input_len,
        max_output_len=args.max_output_len,
        prefill_ratio=1.2,
        block_size=16,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        enable_streamingllm=args.enable_streamingllm,
        start_token_size=args.start_token_size,
        generated_token_size=args.generated_token_size,
    )
coordinator . print_on_master ( f " Initializing Inference Engine... " )
    engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)

    # ==============================
    # Generation
    # ==============================
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_length=args.max_length,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        repetition_penalty=args.repetition_penalty,
    )
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
    coordinator.print_on_master(out)

    # ==============================
    # Optionally, load a drafter model and proceed with speculative decoding
    # ==============================
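    # Speculative decoding drafts a few tokens cheaply with the drafter model and
    # verifies them in one pass with the main model; the drafter is expected to
    # share the main model's tokenizer and vocabulary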
    drafter_model_path_or_name = args.drafter_model
    if drafter_model_path_or_name is not None:
        drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
        # turn on speculative decoding with the drafter model
        engine.enable_spec_dec(drafter_model)
        coordinator.print_on_master("Generating...")
        out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
        coordinator.print_on_master(out)
        engine.disable_spec_dec()

# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2
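# With a drafter model for speculative decoding (via the --drafter_model flag defined below):
# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH --drafter_model DRAFTER_PATH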
if __name__ == " __main__ " :
# ==============================
# Parse Arguments
# ==============================
parser = argparse . ArgumentParser ( )
parser . add_argument ( " -m " , " --model " , type = str , help = " Path to the model or model name " )
parser . add_argument ( " --drafter_model " , type = str , help = " Path to the drafter model or model name " )
parser . add_argument (
" -p " , " --prompt " , type = str , default = " Introduce some landmarks in the United Kingdom, such as " , help = " Prompt "
)
parser . add_argument ( " -b " , " --max_batch_size " , type = int , default = 1 , help = " Max batch size " )
parser . add_argument ( " -i " , " --max_input_len " , type = int , default = 128 , help = " Max input length " )
parser . add_argument ( " -o " , " --max_output_len " , type = int , default = 128 , help = " Max output length " )
parser . add_argument ( " -t " , " --tp_size " , type = int , default = 1 , help = " Tensor Parallelism size " )
parser . add_argument ( " -d " , " --dtype " , type = str , default = " fp16 " , help = " Data type " , choices = [ " fp16 " , " fp32 " , " bf16 " ] )
parser . add_argument ( " --use_cuda_kernel " , action = " store_true " , help = " Use CUDA kernel, use Triton by default " )
    # Generation configs
    parser.add_argument("--max_length", type=int, default=64, help="Max length for generation")
    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
    parser.add_argument("--top_k", type=int, default=50, help="Top-k for generation")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p for generation")
parser . add_argument ( " --enable_streamingllm " , action = " store_true " , help = " Whether to use StreamingLLM " )
parser . add_argument (
" --start_token_size " , type = int , default = 4 , help = " The size of the start_token, When using StreamingLLM, "
)
parser . add_argument (
" --generated_token_size " , type = int , default = 512 , help = " The size of the generated_token, When using StreamingLLM "
)
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        default=0,
        help="If set to a value > 0, no n-gram of that size can occur more than once in the generated text.",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Penalty applied to tokens that already appear in the prompt or the generated text. "
        "Values greater than 1 discourage repetition; values less than 1 encourage it. Defaults to 1.0.",
    )

    args = parser.parse_args()
    infer(args)