@@ -1,10 +1,8 @@
import argparse
import inspect
import json
import math
import os
import resource
import sys
from contextlib import nullcontext
import torch
@@ -26,8 +24,6 @@ logger = get_dist_logger()
def train(args):
    print(colossalai.__version__, inspect.getfile(colossalai))
    print(sys.executable)
    # check lora compatibility
    if "gemini" in args.plugin and args.lora_rank > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin")
@@ -44,38 +40,19 @@ def train(args):
    # ==============================
    init_ctx = nullcontext()
    with init_ctx:
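        # load the base model in bf16 or fp16 according to the mixed-precision setting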
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrain,
            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
            trust_remote_code=True,
        )
        # check if the hybrid parallel plugin is compatible with the model
        try:
            from colossalai.shardformer.policies.auto_policy import get_autopolicy
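            # get_autopolicy raises NotImplementedError when no shardformer policy is registered for this architecture (handled below)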
            policy = get_autopolicy(model)
            if policy is not None:
                if args.plugin in ["zero2", "zero2_cpu"]:
                    # if compatible, switch to the hybrid plugin, which uses colo-attention;
                    # decide CPU offload from the original plugin name before overwriting it
                    args.zero_cpu_offload = args.plugin == "zero2_cpu"
                    args.plugin = "3d"
                    args.zero_stage = 2
                    logger.info(
                        f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
                    )
        except NotImplementedError:
            logger.warning(f"Unable to find a policy for the model, using the {args.plugin} plugin instead")
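        # reload the model with FlashAttention-2 kernels when args.use_flash_attn is set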
        if args.use_flash_attn:
            del model
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                attn_implementation="flash_attention_2",
                trust_remote_code=True,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                trust_remote_code=True,
            )
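        # convert the model to a LoRA module when a positive LoRA rank is requested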
        if args.lora_rank > 0:
            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)