# Adapted from https://github.com/tloen/alpaca-lora/blob/main/generate.py
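# Example usage (the script filename below is an assumption, not fixed by this code):
#   python inference.py <pretrained>                                        # fp16 (default)
#   python inference.py <pretrained> --quant 8bit                           # 8-bit
#   python inference.py <pretrained> --quant 4bit --gptq_checkpoint <ckpt>  # 4-bit GPTQ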
import argparse
from time import time
import torch
from coati.quant import llama_load_quant, low_resource_init
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
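

# Build the Alpaca-style prompt around an instruction and optional input;
# `evaluate` below splits the model output on "### Response:" to recover the reply.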
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""
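

# Generate one response for a single instruction and return (response_text, n_new_tokens).
# The default sampling parameters follow the alpaca-lora script this file was adapted from.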
@torch.no_grad()
def evaluate(
    model,
    tokenizer,
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    n_new_tokens = s.size(0) - input_ids.size(1)
    return output.split("### Response:")[1].strip(), n_new_tokens


instructions = [
    "Tell me about alpacas.",
    "Tell me about the president of Mexico in 2019.",
    "Tell me about the king of France in 2019.",
    "List all Canadian provinces in alphabetical order.",
    "Write a Python program that prints the first 10 Fibonacci numbers.",
    "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
    "Tell me five words that rhyme with 'shock'.",
    "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
    "Count up from 1 to 500.",
    # ===
    "How to play support in legends of league",
" Write a Python program that calculate Fibonacci numbers. " ,
]
inst = [instructions[0]] * 4

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    args = parser.parse_args()

    if args.quant == "4bit":
        assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
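
    # Two loading paths: 4-bit GPTQ builds an empty LLaMA under low_resource_init() and then
    # loads the quantized checkpoint; otherwise from_pretrained handles fp16 or 8-bit directly.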
    if args.quant == "4bit":
        with low_resource_init():
            config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
    model.eval()
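
    # Run each test instruction once, timing the loop to report tokens/s and peak CUDA memory.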
    total_tokens = 0
    start = time()
    for instruction in instructions:
        print(f"Instruction: {instruction}")
        resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
        total_tokens += tokens
        print(f"Response: {resp}")
        print("\n----------------------------\n")
    duration = time() - start
    print(f"Total time: {duration:.3f} s, {total_tokens / duration:.3f} tokens/s")
    print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")