[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()

def train(args):
    lora_config = None
    if args.lora_config is not None:
        lora_config = LoraConfig.from_file(args.lora_config)
    # check lora compatibility
    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()
    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        # Default torch DDP plugin without any acceleration, for debugging purposes
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)
    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )
    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        if args.use_flash_attn:
            ref_model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
        else:
            ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        if args.lora_config is not None:
            model = convert_to_lora_module(model, lora_config=lora_config)
            for name, module in model.named_modules():
                if "norm" in name or "gate" in name:
                    module = module.to(torch.float32)
        disable_dropout(ref_model)
        disable_dropout(model)

    if args.grad_checkpoint:
        # Note: for some models, LoRA may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
    # configure optimizer
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )
    # configure dataset
    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)

    num_desirable = 0
    num_undesirable = 0
    for i in range(len(train_dataset)):
        if train_dataset[i]["label"]:
            num_desirable += 1
        else:
            num_undesirable += 1
    logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")

    # Check if the user-specified weights fit into the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
    actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
    if actual_ratio < 1 or actual_ratio > 4 / 3:
        if not args.auto_weight:
            raise AssertionError(
                f"Desirable weight and undesirable weight are not within the theoretical bounds [1, 4/3]. Actual ratio: {actual_ratio}; please increase/decrease the desirable weight or decrease/increase the undesirable weight."
            )
        else:
            args.desirable_weight = args.desirable_weight / actual_ratio
            coordinator.print_on_master(
                f"Desirable weight and undesirable weight are not within the theoretical bounds [1, 4/3]. Actual ratio: {actual_ratio}; auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
            )
    data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )
    eval_dataloader = None
    if args.eval_dataset:
        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
        eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
        eval_dataloader = plugin.prepare_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided, skip evaluation")
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    if ref_model is not None:
        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
    torch.set_default_dtype(torch.float)
coordinator . print_on_master ( f " Booster init max CUDA memory: { torch . cuda . max_memory_allocated ( ) / 1024 * * 2 : .2f } MB " )
coordinator . print_on_master (
f " Booster init max CPU memory: { resource . getrusage ( resource . RUSAGE_SELF ) . ru_maxrss / 1024 : .2f } MB "
)
start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args . checkpoint_path is not None :
if " modeling " in args . checkpoint_path :
coordinator . print_on_master ( f " Continued pretrain from checkpoint { args . checkpoint_path } " )
booster . load_model ( model , args . checkpoint_path )
else :
coordinator . print_on_master ( f " Load model checkpoint from { args . checkpoint_path } " )
start_epoch , start_step , sampler_start_idx = load_checkpoint (
load_dir = args . checkpoint_path ,
booster = booster ,
model = model ,
optimizer = optim ,
lr_scheduler = lr_scheduler ,
)
assert isinstance ( train_dataloader . sampler , StatefulDistributedSampler )
train_dataloader . sampler . set_start_index ( start_index = sampler_start_idx )
coordinator . print_on_master (
f " Loaded checkpoint { args . checkpoint_path } at epoch { start_epoch } step { start_step } "
)
coordinator . print_on_master ( f " Loaded sample at index { sampler_start_idx } " )
coordinator . print_on_master (
f " Checkpoint loaded max CUDA memory: { torch . cuda . max_memory_allocated ( ) / 1024 * * 2 : .2f } MB "
)
coordinator . print_on_master (
f " Checkpoint loaded CUDA memory: { torch . cuda . memory_allocated ( ) / 1024 * * 2 : .2f } MB "
)
coordinator . print_on_master (
f " Checkpoint loaded max CPU memory: { resource . getrusage ( resource . RUSAGE_SELF ) . ru_maxrss / 1024 : .2f } MB "
)
    trainer = KTOTrainer(
        actor=model,
        ref_model=ref_model,
        booster=booster,
        actor_optim=optim,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=args.save_interval,
        save_dir=args.save_dir,
        coordinator=coordinator,
        beta=args.beta,
        desirable_weight=args.desirable_weight,
        undesirable_weight=args.undesirable_weight,
        apply_loss_mask=not args.disable_loss_mask,
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
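Several of the commits above ("fp8 operators for compressed communication", "fix scaling algorithm in FP8 casting") revolve around per-tensor scaling before an FP8 cast. A minimal sketch of that scaling step follows; it is illustrative only, not ColossalAI's implementation: `cast_to_fp8`/`cast_from_fp8` here are plain-Python stand-ins that skip the bit-level e4m3 mantissa rounding a real cast performs.

```python
# Illustrative per-tensor scaling for an FP8 (e4m3) cast. A sketch of the
# general technique, NOT ColossalAI's tensor ops: scale so the largest
# magnitude maps onto the format's max finite value, clamp, and remember the
# scale so the receiver can undo it.
FP8_E4M3_MAX = 448.0  # largest finite value representable in e4m3

def cast_to_fp8(values):
    """Return (values scaled and clamped to the e4m3 range, scale factor)."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale)) for v in values]
    return scaled, scale

def cast_from_fp8(scaled, scale):
    """Undo the scaling applied by cast_to_fp8."""
    return [v / scale for v in scaled]

grads = [0.5, -2.0, 3.25]
encoded, scale = cast_to_fp8(grads)
decoded = cast_from_fp8(encoded, scale)
```

An `all_reduce_fp8` built on such casts would scale each rank's buffer down, reduce, then scale back up; choosing the scale correctly is exactly what the "fix scaling algorithm in FP8 casting" commit concerns.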
    )
    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if lora_config is not None and lora_config.r > 0:
        # NOTE: set model to eval to merge LoRA weights
        model.eval()

    # save the final model checkpoint after fitting, on rank 0 only
    if args.save_dir is not None:
        coordinator.print_on_master("Start saving final model checkpoint")
        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
        coordinator.print_on_master(
            f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
        )

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
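The `model.eval()` call above merges the LoRA weights before saving. A toy illustration of why switching to eval mode can trigger that merge follows; `ToyLoRALinear` is a hypothetical stand-in, not ColossalChat's LoRA code, and it uses scalars where real LoRA uses weight matrices.

```python
# Hypothetical illustration of eval-time LoRA merging: when leaving training
# mode, a LoRA layer can fold its low-rank update (A @ B, here a scalar
# product) into the base weight so the saved checkpoint needs no adapter.
class ToyLoRALinear:
    def __init__(self, weight, lora_a, lora_b, scaling=1.0):
        self.weight = weight      # base weight (scalar stand-in for a matrix)
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.scaling = scaling
        self.merged = False

    def eval(self):
        # Fold the low-rank update into the base weight exactly once.
        if not self.merged:
            self.weight += self.lora_a * self.lora_b * self.scaling
            self.merged = True

    def forward(self, x):
        if self.merged:
            return x * self.weight
        return x * (self.weight + self.lora_a * self.lora_b * self.scaling)
```

The output is identical before and after merging; only the storage layout changes, which is why merging right before `booster.save_model` is safe.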
if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
    parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
    parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
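The "[Auto Parallel]" commit above notes that `DimSpec`'s `difference_dict` is identical for every instance, so a single shared copy suffices. That caching pattern can be sketched as follows; the class and dict contents here are stand-ins, not the real colossalai `DimSpec`.

```python
# Sketch of the "build the difference dict only once" optimization: when an
# expensive table is identical for every instance, build it lazily at class
# level instead of rebuilding (or deep-copying) it per object.
class DimSpec:
    _difference_dict = None  # shared across all instances

    def __init__(self, shard_list):
        self.shard_list = shard_list

    @classmethod
    def _build_difference_dict(cls):
        # Placeholder for the expensive construction step.
        return {("S0", "R"): 1, ("R", "S0"): 1, ("R", "R"): 0}

    @property
    def difference_dict(self):
        if DimSpec._difference_dict is None:
            DimSpec._difference_dict = self._build_difference_dict()
        return DimSpec._difference_dict
```

Every `DimSpec` then reads the same dict object, which is where the reported intra-op plan-generation speedup comes from.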
    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--tokenizer_dir", type=str, default=None)
    parser.add_argument("--dataset", nargs="+", default=[])
    parser.add_argument("--eval_dataset", nargs="+", default=[])
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training from a checkpoint"
    )
    parser.add_argument("--config_file", type=str, default=None, help="Config file")
    parser.add_argument("--save_dir", type=str, default=None)
    parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
    parser.add_argument("--save_interval", type=int, default=1000, help="number of steps between two checkpoints")
    parser.add_argument("--auto_weight", default=False, action="store_true")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--log_dir", default=None, type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    args = parser.parse_args()

    if args.config_file is not None:
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)
    train(args)
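The `--beta`, `--desirable_weight`, and `--undesirable_weight` flags parsed above parameterize the KTO objective. A rough per-example sketch is shown below; `kto_loss` and its argument names are illustrative, not this repository's API, and the reference-KL term is passed in directly rather than estimated from a batch as a real trainer would.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def kto_loss(log_ratio, is_desirable, ref_kl, beta=0.1,
             desirable_weight=1.0, undesirable_weight=1.0):
    """Simplified per-example KTO loss.

    log_ratio: log pi(y|x) - log pi_ref(y|x) for this completion
    is_desirable: whether the example is labeled desirable
    ref_kl: reference-point KL term (estimated from a batch in practice)
    """
    if is_desirable:
        # Reward desirable completions whose log-ratio exceeds the reference.
        value = sigmoid(beta * (log_ratio - ref_kl))
        return desirable_weight * (1.0 - value)
    # Penalize undesirable completions whose log-ratio exceeds the reference.
    value = sigmoid(beta * (ref_kl - log_ratio))
    return undesirable_weight * (1.0 - value)
```

The two weight flags rebalance the loss when the dataset has unequal numbers of desirable and undesirable examples, and `beta` controls how sharply the sigmoid saturates around the reference point.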