update sft trainning script

pull/5759/head
YeAnbang 6 months ago
parent 2abdede1d7
commit 84eab13078

@ -1,10 +1,8 @@
import argparse
import inspect
import json
import math
import os
import resource
import sys
from contextlib import nullcontext
import torch
@ -26,8 +24,6 @@ logger = get_dist_logger()
def train(args):
print(colossalai.__version__, inspect.getfile(colossalai))
print(sys.executable)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
@ -44,38 +40,19 @@ def train(args):
# ==============================
init_ctx = nullcontext()
with init_ctx:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# check if the hybrid parallel plugin is compatible with the model
try:
from colossalai.shardformer.policies.auto_policy import get_autopolicy
policy = get_autopolicy(model)
if policy is not None:
if args.plugin in ["zero2", "zero2_cpu"]:
# if compatible, set the plugin to hybrid, which use colo-attention
args.plugin = "3d"
args.zero_stage = 2
if args.plugin == "zero2_cpu":
args.zero_cpu_offload = True
else:
args.zero_cpu_offload = False
logger.info(
f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
)
except NotImplementedError:
logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
if args.use_flash_attn:
del model
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

Loading…
Cancel
Save