fix sft single turn inference example (#5416)

pull/5424/head^2
Camille Zhong 2024-03-01 17:27:50 +08:00 committed by GitHub
parent a1c6cdb189
commit 4b8312c08e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 2 deletions

View File

@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
model.to(device)
try:
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
@ -29,6 +29,7 @@ def generate(args):
if args.prompt_style == "sft":
conversation = default_conversation.copy()
conversation.append_message("Human", args.input_txt)
conversation.append_message("Assistant", None)
input_txt = conversation.get_prompt()
else:
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
@ -46,7 +47,7 @@ def generate(args):
num_return_sequences=1,
)
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
return response