fix sft single turn inference example (#5416)

pull/5424/head^2
Camille Zhong 9 months ago committed by GitHub
parent a1c6cdb189
commit 4b8312c08e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
model.to(device) model.to(device)
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
except OSError: except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
@ -29,6 +29,7 @@ def generate(args):
if args.prompt_style == "sft": if args.prompt_style == "sft":
conversation = default_conversation.copy() conversation = default_conversation.copy()
conversation.append_message("Human", args.input_txt) conversation.append_message("Human", args.input_txt)
conversation.append_message("Assistant", None)
input_txt = conversation.get_prompt() input_txt = conversation.get_prompt()
else: else:
BASE_INFERENCE_SUFFIX = "\n\n->\n\n" BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
@ -46,7 +47,7 @@ def generate(args):
num_return_sequences=1, num_return_sequences=1,
) )
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True) response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}") logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
return response return response

Loading…
Cancel
Save