[llama] fix training and inference scripts (#5384)

* [llama] refactor inference example to fit sft

* [llama] fix training script to fit gemini

* [llama] fix inference script
pull/5389/head
Hongxin Liu 2024-02-19 16:41:04 +08:00 committed by GitHub
parent adae123df3
commit 7303801854
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 30 deletions

View File

@ -1,17 +1,16 @@
import argparse
import os
import torch
from colossal_llama2.dataset.conversation import default_conversation
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.logging import get_dist_logger
from transformers import AutoTokenizer, AutoModelForCausalLM
logger = get_dist_logger()
def load_model(model_path, device="cuda", **kwargs):
logger.info(
"Please check whether the tokenizer and model weights are properly stored in the same folder."
)
logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
model.to(device)
@ -27,31 +26,50 @@ def load_model(model_path, device="cuda", **kwargs):
def generate(args):
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
if args.prompt_style == "sft":
conversation = default_conversation.copy()
conversation.append_message("Human", args.input_txt)
input_txt = conversation.get_prompt()
else:
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
output = model.generate(**inputs,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
num_return_sequences=1)
response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
num_input_tokens = inputs["input_ids"].shape[-1]
output = model.generate(
**inputs,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
num_return_sequences=1,
)
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation")
parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
parser.add_argument(
"--model_path",
type=str,
default="hpcai-tech/Colossal-LLaMA-2-7b-base",
help="HF repo name or local path of the model",
)
parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
parser.add_argument(
"--max_new_tokens",
type=int,
default=512,
help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
)
parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
parser.add_argument("--top_p", type=int, default=0.95, help="Set top_p value for generation")
parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
args = parser.parse_args()
generate(args)

View File

@ -154,6 +154,7 @@ def main() -> None:
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -161,6 +162,7 @@ def main() -> None:
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(

View File

@ -726,11 +726,13 @@ class GeminiDDP(ModelWrapper):
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
del temp_chunk
if self.reuse_fp16_chunk:
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.payload.copy_(chunk_32.payload)
# sync running weights and master weights
if self.master_weights:
for loaded_chunk in chunk_list:
paired_chunk = loaded_chunk.paired_chunk
assert paired_chunk is not None
paired_chunk.payload.copy_(loaded_chunk.payload)
for name, buf in persistent_buffers.items():
if buf is not None: