mirror of https://github.com/hpcaitech/ColossalAI
fix sft single turn inference example (#5416)
parent a1c6cdb189
commit 4b8312c08e
@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
     model.to(device)

     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
     except OSError:
         raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")

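Note on the padding_side change: decoder-only models generate by appending to the end of the input, so when prompts are padded on the right the pad tokens sit between the prompt and the continuation. The sketch below illustrates the idea with a hypothetical model path; it is not code from this repository.

# Minimal sketch (hypothetical model path), showing why left padding is used
# for decoder-only generation.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/your/model"  # placeholder, not a real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # make batched padding possible
model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")

# With left padding, every prompt ends flush against the generated tokens,
# so the continuation directly follows the prompt text.
inputs = tokenizer(["Hello", "A much longer prompt"], return_tensors="pt", padding=True).to("cuda")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(output, skip_special_tokens=True))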
@@ -29,6 +29,7 @@ def generate(args):
     if args.prompt_style == "sft":
         conversation = default_conversation.copy()
         conversation.append_message("Human", args.input_txt)
+        conversation.append_message("Assistant", None)
         input_txt = conversation.get_prompt()
     else:
         BASE_INFERENCE_SUFFIX = "\n\n->\n\n"

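The added line appends an empty "Assistant" turn so that the rendered prompt ends with the assistant header and the model completes the answer instead of continuing the human turn. The snippet below is a simplified stand-in for the repository's default_conversation object, written only to illustrate that assumed behaviour; the real template's roles and separators may differ.

# Hypothetical re-implementation of the conversation-template behaviour the
# diff relies on; not the repo's actual default_conversation class.
from dataclasses import dataclass, field

@dataclass
class Conversation:
    messages: list = field(default_factory=list)
    sep: str = "\n\n"

    def append_message(self, role, content):
        self.messages.append((role, content))

    def get_prompt(self):
        parts = []
        for role, content in self.messages:
            if content is None:
                # Open the turn but leave it empty so the model fills it in.
                parts.append(f"{role}: ")
            else:
                parts.append(f"{role}: {content}{self.sep}")
        return "".join(parts)

conv = Conversation()
conv.append_message("Human", "What is SFT?")
conv.append_message("Assistant", None)
print(conv.get_prompt())
# -> "Human: What is SFT?\n\nAssistant: "  (the model now generates the answer)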
@@ -46,7 +47,7 @@ def generate(args):
         num_return_sequences=1,
     )
     response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
-    logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
+    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
     return response

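For context on the decode line above: generate() returns the prompt tokens followed by the continuation, so the script slices off the first num_input_tokens before decoding and logs only the model's answer. A rough sketch of that step, assuming the model, tokenizer, and input_txt from the examples above:

# Sketch of the decode step; num_input_tokens is the prompt length in tokens.
inputs = tokenizer(input_txt, return_tensors="pt").to(model.device)
num_input_tokens = inputs["input_ids"].shape[1]

output = model.generate(**inputs, max_new_tokens=256, do_sample=True, num_return_sequences=1)

# Drop the echoed prompt so the logged "Assistant" answer contains only the
# newly generated tokens.
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
print(f"\nHuman: {input_txt} \n\nAssistant: \n{response}")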