fix sft single turn inference example (#5416)

pull/5424/head^2
Camille Zhong 2024-03-01 17:27:50 +08:00 committed by GitHub
parent a1c6cdb189
commit 4b8312c08e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 2 deletions

View File

@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
     model.to(device)
     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
     except OSError:
         raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
@@ -29,6 +29,7 @@ def generate(args):
     if args.prompt_style == "sft":
         conversation = default_conversation.copy()
         conversation.append_message("Human", args.input_txt)
+        conversation.append_message("Assistant", None)
         input_txt = conversation.get_prompt()
     else:
         BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
@@ -46,7 +47,7 @@ def generate(args):
         num_return_sequences=1,
     )
     response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
-    logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
+    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
     return response