mirror of https://github.com/hpcaitech/ColossalAI
fix sft single turn inference example (#5416)
parent
a1c6cdb189
commit
4b8312c08e
|
@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
|
|||
model.to(device)
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
|
||||
except OSError:
|
||||
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
|
||||
|
||||
|
@ -29,6 +29,7 @@ def generate(args):
|
|||
if args.prompt_style == "sft":
|
||||
conversation = default_conversation.copy()
|
||||
conversation.append_message("Human", args.input_txt)
|
||||
conversation.append_message("Assistant", None)
|
||||
input_txt = conversation.get_prompt()
|
||||
else:
|
||||
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
||||
|
@ -46,7 +47,7 @@ def generate(args):
|
|||
num_return_sequences=1,
|
||||
)
|
||||
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
|
||||
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
|
||||
logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
|
||||
return response
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue