mirror of https://github.com/hpcaitech/ColossalAI
[llama] fix training and inference scripts (#5384)
* [llama] refactor inference example to fit sft * [llama] fix training script to fit gemini * [llama] fix inference scriptpull/5389/head
parent
adae123df3
commit
7303801854
|
@ -1,17 +1,16 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from colossal_llama2.dataset.conversation import default_conversation
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path, device="cuda", **kwargs):
|
def load_model(model_path, device="cuda", **kwargs):
|
||||||
logger.info(
|
logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
|
||||||
"Please check whether the tokenizer and model weights are properly stored in the same folder."
|
|
||||||
)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
@ -27,31 +26,50 @@ def load_model(model_path, device="cuda", **kwargs):
|
||||||
def generate(args):
|
def generate(args):
|
||||||
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
|
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
|
||||||
|
|
||||||
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
if args.prompt_style == "sft":
|
||||||
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
|
conversation = default_conversation.copy()
|
||||||
|
conversation.append_message("Human", args.input_txt)
|
||||||
|
input_txt = conversation.get_prompt()
|
||||||
|
else:
|
||||||
|
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
||||||
|
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
|
||||||
|
|
||||||
inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
|
inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
|
||||||
output = model.generate(**inputs,
|
num_input_tokens = inputs["input_ids"].shape[-1]
|
||||||
max_new_tokens=args.max_new_tokens,
|
output = model.generate(
|
||||||
do_sample=args.do_sample,
|
**inputs,
|
||||||
temperature=args.temperature,
|
max_new_tokens=args.max_new_tokens,
|
||||||
top_k=args.top_k,
|
do_sample=args.do_sample,
|
||||||
top_p=args.top_p,
|
temperature=args.temperature,
|
||||||
num_return_sequences=1)
|
top_k=args.top_k,
|
||||||
response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
|
top_p=args.top_p,
|
||||||
|
num_return_sequences=1,
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
|
||||||
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
|
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
|
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
|
||||||
parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
|
parser.add_argument(
|
||||||
parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
|
"--model_path",
|
||||||
parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
|
type=str,
|
||||||
parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
|
default="hpcai-tech/Colossal-LLaMA-2-7b-base",
|
||||||
parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
|
help="HF repo name or local path of the model",
|
||||||
parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
|
)
|
||||||
parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation")
|
parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
|
||||||
parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
|
parser.add_argument(
|
||||||
|
"--max_new_tokens",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
|
||||||
|
)
|
||||||
|
parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
|
||||||
|
parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
|
||||||
|
parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
|
||||||
|
parser.add_argument("--top_p", type=int, default=0.95, help="Set top_p value for generation")
|
||||||
|
parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
|
||||||
|
parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
generate(args)
|
generate(args)
|
||||||
|
|
|
@ -154,6 +154,7 @@ def main() -> None:
|
||||||
precision=args.mixed_precision,
|
precision=args.mixed_precision,
|
||||||
initial_scale=2**16,
|
initial_scale=2**16,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
)
|
)
|
||||||
elif args.plugin == "gemini_auto":
|
elif args.plugin == "gemini_auto":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
|
@ -161,6 +162,7 @@ def main() -> None:
|
||||||
placement_policy="auto",
|
placement_policy="auto",
|
||||||
initial_scale=2**16,
|
initial_scale=2**16,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
)
|
)
|
||||||
elif args.plugin == "zero2":
|
elif args.plugin == "zero2":
|
||||||
plugin = LowLevelZeroPlugin(
|
plugin = LowLevelZeroPlugin(
|
||||||
|
|
|
@ -726,11 +726,13 @@ class GeminiDDP(ModelWrapper):
|
||||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||||
|
|
||||||
del temp_chunk
|
del temp_chunk
|
||||||
if self.reuse_fp16_chunk:
|
|
||||||
for chunk_32 in chunk_list:
|
# sync running weights and master weights
|
||||||
chunk_16 = chunk_32.paired_chunk
|
if self.master_weights:
|
||||||
assert chunk_16 is not None
|
for loaded_chunk in chunk_list:
|
||||||
chunk_16.payload.copy_(chunk_32.payload)
|
paired_chunk = loaded_chunk.paired_chunk
|
||||||
|
assert paired_chunk is not None
|
||||||
|
paired_chunk.payload.copy_(loaded_chunk.payload)
|
||||||
|
|
||||||
for name, buf in persistent_buffers.items():
|
for name, buf in persistent_buffers.items():
|
||||||
if buf is not None:
|
if buf is not None:
|
||||||
|
|
Loading…
Reference in New Issue