@@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama.utils.froze import freeze_non_embeds_parameters
 from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
 from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
+from peft import LoraConfig
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
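The `torch.cuda.is_available()` → `get_accelerator().is_available()` swaps above make the fused-normalization check device-agnostic instead of CUDA-only. A minimal sketch of the pattern, assuming ColossalAI's `colossalai.accelerator` module is importable:

    # Sketch: query the active accelerator instead of hard-coding CUDA.
    # Assumes ColossalAI is installed; the backend returned depends on the host.
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()           # CUDA, NPU, or CPU backend object
    print(accelerator.is_available())         # replaces torch.cuda.is_available()
    print(accelerator.get_current_device())   # replaces torch.cuda.current_device()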
@@ -101,10 +102,9 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
-            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
             microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@ def train(args) -> None:
     # ======================================================
     # Initialize Tokenizer, Dataset, Collator and Dataloader
     # ======================================================
-    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
     if args.pad_token == "eos":
-        tokenizer.pad_token = tokenizer.eos_token
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     elif args.pad_token == "unk":
-        tokenizer.pad_token = tokenizer.unk_token
+        try:
+            tokenizer.pad_token = tokenizer.unk_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
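The try/except added above matters for tokenizers loaded with `trust_remote_code=True`, where `pad_token` can be a read-only property. A self-contained sketch of the failure mode being guarded; the tokenizer class here is hypothetical:

    # Hypothetical tokenizer whose pad_token is a read-only property:
    # assigning to it raises AttributeError, which the patch catches.
    class CustomTokenizer:
        @property
        def pad_token(self):
            return "<pad>"

    tokenizer = CustomTokenizer()
    try:
        tokenizer.pad_token = "</s>"
    except AttributeError:
        print("pad_token can't be set")  # mirrors the coordinator message above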
@@ -164,33 +170,31 @@ def train(args) -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
+    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
     init_ctx = (
         LazyInitContext(default_device=get_current_device())
-        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
         else nullcontext()
     )
     with init_ctx:
-        if args.use_flash_attn:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrained,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
+    if args.lora_rank > 0:
+        lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
+        model = booster.enable_lora(model, lora_config=lora_config)
     # this is essential, otherwise the grad checkpoint will not work.
     model.train()
     if args.use_grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
     model_numel = get_model_numel(model)
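For reference, the LoRA hook-up above goes through the booster, but the same configuration can be exercised with peft directly. A sketch, under the assumption that plain peft wrapping approximates what `booster.enable_lora` does (the model name is illustrative):

    # Sketch: equivalent LoRA wrapping with plain peft (no booster involved).
    # Hyperparameters mirror the values hard-coded in the patch above.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model
    lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # only the low-rank adapter weights require grads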
@@ -327,6 +331,7 @@ def train(args) -> None:
                         step=step + 1,
                         batch_size=args.batch_size,
                         coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
                     )
                     coordinator.print_on_master(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
                     total_loss.fill_(0.0)
                     pbar.update()
 
-                    # Save modeling.
-                    save_model_condition = (
-                        args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
-                    )
+                # Save modeling.
+                save_model_condition = (
+                    args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                )
 
-                    if not args.skip_save_each_epoch:
-                        save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+                if not args.skip_save_each_epoch:
+                    save_model_condition = save_model_condition or (step + 1) == len(dataloader)
 
-                    if save_model_condition and not args.benchmark:
-                        coordinator.print_on_master("\nStart saving model checkpoint with running states")
+                if save_model_condition and not args.benchmark:
+                    coordinator.print_on_master("\nStart saving model checkpoint with running states")
 
-                        if args.use_neft:
-                            coordinator.print_on_master("Deactivate NEFTune before saving model.")
-                            deactivate_neftune(model, handle)
+                    if args.use_neft:
+                        coordinator.print_on_master("Deactivate NEFTune before saving model.")
+                        deactivate_neftune(model, handle)
 
-                        accelerator.empty_cache()
-                        save_checkpoint(
-                            save_dir=args.save_dir,
-                            booster=booster,
-                            model=model,
-                            optimizer=optimizer,
-                            lr_scheduler=lr_scheduler,
-                            epoch=epoch,
-                            step=step + 1,
-                            batch_size=args.batch_size,
-                            coordinator=coordinator,
-                        )
-                        coordinator.print_on_master(
-                            f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
-                        )
+                    accelerator.empty_cache()
+                    save_checkpoint(
+                        save_dir=args.save_dir,
+                        booster=booster,
+                        model=model,
+                        optimizer=optimizer,
+                        lr_scheduler=lr_scheduler,
+                        epoch=epoch,
+                        step=step + 1,
+                        batch_size=args.batch_size,
+                        coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
+                    )
+                    coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+                    )
 
-                        if args.use_neft:
-                            coordinator.print_on_master("Activate NEFTune.")
-                            model, handle = activate_neftune(model)
+                    if args.use_neft:
+                        coordinator.print_on_master("Activate NEFTune.")
+                        model, handle = activate_neftune(model)
 
-                    # Delete cache.
-                    # del batch, batch_labels, batch_output, loss
-                    accelerator.empty_cache()
+                # Delete cache.
+                # del batch, batch_labels, batch_output, loss
+                accelerator.empty_cache()
 
         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
    )
+    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
     # Additional arguments for benchmark.
     parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")