update sft training script

pull/5759/head
YeAnbang 2024-06-11 02:44:20 +00:00
parent 2abdede1d7
commit 84eab13078
1 changed file with 6 additions and 29 deletions


@@ -1,10 +1,8 @@
 import argparse
-import inspect
 import json
 import math
 import os
 import resource
-import sys
 from contextlib import nullcontext
 
 import torch
@@ -26,8 +24,6 @@ logger = get_dist_logger()
 
 
 def train(args):
-    print(colossalai.__version__, inspect.getfile(colossalai))
-    print(sys.executable)
     # check lora compatibility
     if "gemini" in args.plugin and args.lora_rank > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
@@ -44,38 +40,19 @@ def train(args):
     # ==============================
     init_ctx = nullcontext()
     with init_ctx:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.pretrain,
-            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-            trust_remote_code=True,
-        )
-        # check if the hybrid parallel plugin is compatible with the model
-        try:
-            from colossalai.shardformer.policies.auto_policy import get_autopolicy
-
-            policy = get_autopolicy(model)
-            if policy is not None:
-                if args.plugin in ["zero2", "zero2_cpu"]:
-                    # if compatible, set the plugin to hybrid, which use colo-attention
-                    args.plugin = "3d"
-                    args.zero_stage = 2
-                    if args.plugin == "zero2_cpu":
-                        args.zero_cpu_offload = True
-                    else:
-                        args.zero_cpu_offload = False
-                logger.info(
-                    f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
-                )
-        except NotImplementedError:
-            logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
         if args.use_flash_attn:
-            del model
             model = AutoModelForCausalLM.from_pretrained(
                 args.pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 attn_implementation="flash_attention_2",
                 trust_remote_code=True,
             )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.pretrain,
+                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                trust_remote_code=True,
+            )
         if args.lora_rank > 0:
             model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
 
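
For reference, after this commit the model construction in the SFT training script reduces to a single flash-attention-aware branch. Below is a minimal, standalone sketch of that logic, reconstructed from the hunk above; build_model is a hypothetical wrapper name, and args.pretrain, args.mixed_precision, and args.use_flash_attn are the script's existing command-line arguments:

    import torch
    from transformers import AutoModelForCausalLM


    def build_model(args):
        """Hypothetical helper mirroring the model-loading branch of train() after this change."""
        # Pick the dtype from the existing --mixed_precision flag.
        dtype = torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16
        if args.use_flash_attn:
            # Request FlashAttention-2 kernels from HF transformers.
            return AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2",
                trust_remote_code=True,
            )
        # Fall back to the default attention implementation.
        return AutoModelForCausalLM.from_pretrained(
            args.pretrain,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

With the auto-policy probe removed, the plugin selected on the command line (e.g. zero2 or zero2_cpu) is now used as-is rather than being silently switched to the 3d hybrid plugin.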