mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.5 KiB
47 lines
1.5 KiB
import argparse
|
|
|
|
import torch
|
|
|
|
|
|
class Bcolors:
    """ANSI terminal escape sequences used to colour/style console output.

    Wrap text as ``f"{Bcolors.OKBLUE}...{Bcolors.ENDC}"``; ``ENDC`` resets
    all attributes back to the terminal default.
    """

    # --- foreground colours -------------------------------------------------
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # --- reset -------------------------------------------------------------
    ENDC = "\033[0m"

    # --- text attributes ----------------------------------------------------
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
|
|
|
|
|
|
def print_output(text, output):
    """Print a ``-----`` separator followed by *output* with its prompt highlighted.

    The first ``len(text)`` characters of *output* are not printed; *text*
    itself is printed in their place, wrapped in ``Bcolors.OKBLUE`` /
    ``Bcolors.ENDC``, followed by the rest of *output*.
    NOTE(review): this presumes *output* begins with *text* (e.g. a decoded
    generation that echoes the prompt) — confirm with callers.
    """
    continuation = output[len(text):]
    highlighted_prompt = f"{Bcolors.OKBLUE}{text}{Bcolors.ENDC}"
    print(f"-----\n{highlighted_prompt}{continuation}")
|
|
|
|
|
|
@torch.no_grad()
def inference(model, sp, text, **generate_kwargs):
    """Tokenize *text*, run ``model.generate`` on it, and return the first sequence.

    Args:
        model: object exposing ``generate(**kwargs)``; its result is indexed
            with ``[0]`` and converted with ``.tolist()``.
        sp: tokenizer exposing ``encode(text)``; presumably returns a list of
            token ids (TODO confirm).
        text: prompt to encode.
        **generate_kwargs: forwarded to ``model.generate``. Keys here override
            the ``input_ids`` / ``attention_mask`` built below.

    Returns:
        The first generated sequence as a Python list.

    Runs under ``torch.no_grad()`` and places the batch on CUDA via ``.cuda()``.
    """
    token_ids = sp.encode(text)
    # Single-example batch -> shape (1, len(token_ids)), moved to the GPU.
    batch = torch.tensor([token_ids]).cuda()

    call_kwargs = {
        "input_ids": batch,
        # All positions are attended to: no padding in a single-prompt batch.
        "attention_mask": torch.ones_like(batch),
    }
    # update() keeps the original "later key wins" merge semantics.
    call_kwargs.update(generate_kwargs)

    generated = model.generate(**call_kwargs)
    first_sequence = generated[0]
    return first_sequence.tolist()
|
|
|
|
|
|
def get_defualt_parser():
    """Build the command-line argument parser shared by the inference scripts.

    Options (defaults in parentheses):
        --pretrained (``"hpcaitech/grok-1"``), --tokenizer (``"tokenizer.model"``),
        --text: one or more prompt strings (``["Hi, what's your name?"]``),
        --max_new_tokens (30), --do_sample flag (False),
        --temperature (0.3), --top_k (50), --top_p (0.95).

    The name is misspelled; ``get_default_parser`` is provided as a correctly
    spelled alias, and this name is kept so existing callers keep working.

    Returns:
        argparse.ArgumentParser: a freshly constructed parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
    parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
    # nargs="+" lets several prompts be passed in one run; the result is a list.
    parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"])
    parser.add_argument("--max_new_tokens", type=int, default=30)
    parser.add_argument("--do_sample", action="store_true", default=False)
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    return parser


# Correctly spelled alias for the misspelled public function above.
get_default_parser = get_defualt_parser
|