mirror of https://github.com/InternLM/InternLM
Merge develop to main (#233)
* feat(utils/writer.py): support tensorboard writer (#63) * feat(utils/writer.py): support tensorboard writer * feat(utils/writer.py): add class comment --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * [Develop] Pull Main Branch (#121) * fix/fix_submodule_err (#61) * fix/fix_submodule_err --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> * fix issue templates (#65) * fix(tokenizer): refactor tokenizer and update usage in readme (#51) * update tokenizer example * fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73) * fix a typo in readme * in order to find InternLMTokenizer, select a lower version of Transformers --------- Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> * [Doc] Add wechat and discord link in readme (#78) * Doc:add wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * [Docs]: add Japanese README (#43) * Add Japanese README * Update README-ja-JP.md replace message * Update README-ja-JP.md * add repetition_penalty in GenerationConfig in web_demo.py (#48) Co-authored-by: YWMditto <862779238@qq.com> * use fp16 in instruction (#80) * [Enchancement] add more options for issue template (#77) * [Enchancement] add more options for issue template * update qustion icon * fix link * Use tempfile for convert2hf.py (#23) Fix https://github.com/InternLM/InternLM/issues/50 * delete torch_dtype of README's example code (#100) * set the value of repetition_penalty to 1.0 to avoid random outputs (#99) * Update web_demo.py (#97) Remove meaningless log. * [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106) --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: Kai Chen <chenkaidev@gmail.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com> Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> Co-authored-by: vansin <msnode@163.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com> Co-authored-by: YWMditto <862779238@qq.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * feat(core/scheduler): support pipeline parallel (#98) * feat(utils/writer.py): support tensorboard writer * feat(utils/writer.py): add class comment * feat(core): support pipeline parallel * fix(core): fix demo running error * feat(solver/optimizer): add pp zero optimizer * fix(solver/optimizer): fix word spelling error * feat(core/scheduler): add new dir scheduler in core/ * fix(core): fix ci lint error * feat(solver/optimizer): merge pp and nopp optimizer * doc(usage.md): update usage doc * feat(core/scheduler): support post func * feat(core/scheduler): add dtype para in pp sche and update func get_tensor_shape * feat(core/scheduler): add _load_micro_batch in base scheduler * feat(core/scheduler): support optimizer overlap communication in pp scheduler * feat(core/scheduler): delete data process func code * feat(core/trainer): schedule pre processing for all schedule --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com> * refactor(rotaryEmbedding): refactor forward (#120) * use fp16 in instruction (#80) * delete torch_dtype of README's example code (#100) * refactor the forward for rotary embedding --------- Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> * feat(model/metrics.py): support calculating accuracy and perplexity m… (#91) * feat(model/metrics.py): support calculating accuracy and perplexity metrics * fix(model/metrics.py): fix import error * feat(train.py): minor update --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com> * fix(optimizer/util.py) change inf defination * [Dev] Pull Main (#139) * fix/fix_submodule_err (#61) * fix/fix_submodule_err --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> * fix issue templates (#65) * fix(tokenizer): refactor tokenizer and update usage in readme (#51) * update tokenizer example * fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73) * fix a typo in readme * in order to find InternLMTokenizer, select a lower version of Transformers --------- Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> * [Doc] Add wechat and discord link in readme (#78) * Doc:add wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * [Docs]: add Japanese README (#43) * Add Japanese README * Update README-ja-JP.md replace message * Update README-ja-JP.md * add repetition_penalty in GenerationConfig in web_demo.py (#48) Co-authored-by: YWMditto <862779238@qq.com> * use fp16 in instruction (#80) * [Enchancement] add more options for issue template (#77) * [Enchancement] add more options for issue template * update qustion icon * fix link * Use tempfile for convert2hf.py (#23) Fix https://github.com/InternLM/InternLM/issues/50 * delete torch_dtype of README's example code (#100) * set the value of repetition_penalty to 1.0 to avoid random outputs (#99) * Update web_demo.py (#97) Remove meaningless log. * [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106) * docs(install.md): update dependency package transformers version to >= 4.28.0 (#124) Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * docs(LICENSE): add license (#125) * add license of colossalai and flash-attn * fix lint * modify the name * fix AutoModel map in convert2hf.py (#116) * variables are not printly as expect (#114) * feat(solver): fix code to adapt to torch2.0 and provide docker images (#128) * feat(solver): fix code to adapt to torch2.0 * docs(install.md): publish internlm environment image * docs(install.md): update dependency packages version * docs(install.md): update default image --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * add demo test (#132) Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * fix web_demo cache accelerate (#133) * fix(hybrid_zero_optim.py): delete math import * Update embedding.py --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: Kai Chen <chenkaidev@gmail.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com> Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> Co-authored-by: vansin <msnode@163.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com> Co-authored-by: YWMditto <862779238@qq.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: huangting4201 <1538303371@qq.com> Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com> Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com> * style(solver/optimizer/utils.py): fix lint error (#147) Co-authored-by: huangting.p <huangting@sensetime.com> * feat(*): support not-flash-attn for pp and no-pp (#145) * support not flash attention for no-pp * support pipeline * modify the config * refactor the code * refactor the code * remove some unnecessary code * fix(initialize/launch.py): set default value for use_flash_attn (#158) * add default for use_flash_attn * fix lint * feat(utils/logger.py): support uniscale logger (#152) * style(internlm): fix lint error * feat(utils/logger.py): support uniscale logger * fix(utils/logger.py): fix import circular error * feat(train.py): support dashboard metric panel and fix ci train config * fix(ci_scripts/train/slurm_train.sh): fix ci train error * fix(ci_scripts/train/torchrun.sh): fix ci train error * fix(ci_scripts/train): restore ci update * fix(config.json): delete alert webhook * feat(train.py): optimize func init logger * feat(config.json): delete config.json --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com> * feat(utils/evaluation.py): support evaluate (#154) * style(internlm): fix lint error * feat(utils/logger.py): support uniscale logger * fix(utils/logger.py): fix import circular error * feat(train.py): support dashboard metric panel and fix ci train config * fix(ci_scripts/train/slurm_train.sh): fix ci train error * fix(ci_scripts/train/torchrun.sh): fix ci train error * feat(utils/evaluation.py): support evaluate on validation dataset * fix(utils/evaluation.py): fix demo error * fix(ci_scripts/train/ci_7B_sft.py): fix ci train error * feat(initialize/launch.py): set default value for valid_bsz and valid_every * fix(ci_scripts/train): restore ci update * docs(configs/7B_sft.py): update comment for config * fix(config.json): delete config.json * fix evaluation bug in scheduler when use_flash_attn=False * feat(scheduler/no_pipeline_scheduler.py): support micro_bsz>1 in no pp * modify the jugement in pp and no-pp scheduler * modify the data_process_func in evaluation * fix bugs when use_flash_attn=False * rename symbol * feat(configs/7B_sft.py): change para valid_bsz to valid_micro_num * feat(scheduler/no_pipeline_scheduler.py): update para set _grad_accum_batch_size --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com> Co-authored-by: yingtongxiong <974106207@qq.com> * feat(*): support no apex (#166) * support no-apex * add default for use_apex * fix lint * modify the RMSNormTorch * remove some comments * remove use_apex parameter * remove some unnecessary code * refactor(*): refactor the code with no-apex (#170) * support no-apex * add default for use_apex * fix lint * modify the RMSNormTorch * remove some comments * remove use_apex parameter * remove some unnecessary code * optimize the code including import * remove the import RMSNorm * remove warnings * refactor(scheduler): rewrite pipeline scheduler (#138) * refactor(scheduler): rewrite pipeline scheduler * fix(*): fix pipeline scheduler bugs * fix(*): fix merge bug * feat(*): update codes with todo tag * feat(*): add comments * feat(internlm/core/scheduler): update recv_prev/next logic * feat(utils/evaluation.py): update sche metric hook for valid --------- Co-authored-by: huangting.p <huangting@sensetime.com> * feat(*): support fp32 training (#155) * support float32 training * fix lint * add adaptation in model/utils.py * remove some unnecessary code * fix lint * feat(optim): add support for fp32 zero * Revert "Merge pull request #2 from SolenoidWGT/fp32_zero" This reverts commitpull/238/head v0.2.053fc50b0e5
, reversing changes made to40f24d0a73
. revert commit * merge develop * Update utils.py * support fp32 in zero optimizer * modify the dtype --------- Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * feat(*): support sequence_parallel (#180) * support sequence_parallel for no pipeline * sequence_parallel does not support no-flash-attn * support sequence parallel for pipeline * add memory profiler * Update 13B.py * add memory profiler * fix evaluation bug * remove some unnecessary code * remove some unnecessary code * Update parallel_context.py * modify the config * remove memory profiler * modify the config * support selective dropout * feat(monitor): support monitor and alert (#175) * feat(monitor): support monitor and alert * feat(monitor.py): fix demo error * feat(monitor.py): move cmd monitor args to config file * feat(hybrid_zero_optim.py): if overflow occurs send alert msg * feat(monitor.py): remove alert msg filter * feat(monitor.py): optimize class MonitorTracker * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(monitor.py): optimize code * feat(train.py): update print to log * style(ci): fix lint error * fix(utils/evaluation.py): remove useless code * fix(model/modeling_internlm.py): fix lint error --------- Co-authored-by: huangting4201 <huangting3@sensetime.com> * feat(ckpt): add async upload and ckpt snapshot (#161) * use fp16 in instruction (#80) * delete torch_dtype of README's example code (#100) * feat(ckpt): support async ckpt upload and ckpt snapshot --------- Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * feat(ckpt): add auto ckpt load and singal quit (#189) Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * Revert "feat(ckpt): add auto ckpt load and singal quit (#189)" (#192) This reverts commita45a91bb84
. * refactor(solver/optimizer): improve optimizer memory (#193) * refactor(solver/optimizer): improve optimizer memory * feat(data): remove useless dataset type ids map * Feat/optimizer (#194) * feat(optimier.py): reduce memory footprint and avoid _check_overflow call * feat(optimier.py): reduce memory footprint and avoid _check_overflow call * feat(optimizer.py): overlap compute norm with allreduce * update var and function name * update function compute norm (#197) Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> * feat(optimizer/hybrid_zero_optim.py): overlap gradients last bucket allreduce and compute norm (#196) * support gradients allreduce and compute norm overlap * fix para set error * remove timer cal_norm for testing * feat(optimizer/hybrid_zero_optim.py): support group global norm * format(lint): fix lint error * feat(optimizer/store.py): update code based on comment --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: huangting4201 <1538303371@qq.com> * fix(ci): fix ci train error (#199) * fix/ci train error (#200) * fix(ci): fix ci train error * fix(ci): fix ci train error * fix(ci): fix ci train error * fix(train.py): fix scheduler metric hook skip error (#204) * Merge main to develop (#203) * fix/fix_submodule_err (#61) * fix/fix_submodule_err --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> * fix issue templates (#65) * fix(tokenizer): refactor tokenizer and update usage in readme (#51) * update tokenizer example * fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73) * fix a typo in readme * in order to find InternLMTokenizer, select a lower version of Transformers --------- Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> * [Doc] Add wechat and discord link in readme (#78) * Doc:add wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * [Docs]: add Japanese README (#43) * Add Japanese README * Update README-ja-JP.md replace message * Update README-ja-JP.md * add repetition_penalty in GenerationConfig in web_demo.py (#48) Co-authored-by: YWMditto <862779238@qq.com> * use fp16 in instruction (#80) * [Enchancement] add more options for issue template (#77) * [Enchancement] add more options for issue template * update qustion icon * fix link * Use tempfile for convert2hf.py (#23) Fix https://github.com/InternLM/InternLM/issues/50 * delete torch_dtype of README's example code (#100) * set the value of repetition_penalty to 1.0 to avoid random outputs (#99) * Update web_demo.py (#97) Remove meaningless log. * [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106) * docs(install.md): update dependency package transformers version to >= 4.28.0 (#124) Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * docs(LICENSE): add license (#125) * add license of colossalai and flash-attn * fix lint * modify the name * fix AutoModel map in convert2hf.py (#116) * variables are not printly as expect (#114) * feat(solver): fix code to adapt to torch2.0 and provide docker images (#128) * feat(solver): fix code to adapt to torch2.0 * docs(install.md): publish internlm environment image * docs(install.md): update dependency packages version * docs(install.md): update default image --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * add demo test (#132) Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * fix web_demo cache accelerate (#133) * Doc: add twitter link (#141) * Feat add checkpoint fraction (#151) * feat(config): add checkpoint_fraction into config * feat: remove checkpoint_fraction from configs/7B_sft.py --------- Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * [Doc] update deployment guide to keep consistency with lmdeploy (#136) * update deployment guide * fix error * use llm partition (#159) Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * test(ci_scripts): clean test data after test, remove unnecessary global variables, and other optimizations (#165) * test: optimization of ci scripts(variables, test data cleaning, etc). * chore(workflows): disable ci job on push. * fix: update partition * test(ci_scripts): add install requirements automaticlly,trigger event about lint check and other optimizations (#174) * add pull_request in lint check * use default variables in ci_scripts * fix format * check and install requirements automaticlly * fix format --------- Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * feat(profiling): add a simple memory profiler (#89) * feat(profiling): add simple memory profiler * feat(profiling): add profiling argument * feat(CI_workflow): Add PR & Issue auto remove workflow (#184) * feat(ci_workflow): Add PR & Issue auto remove workflow Add a workflow for stale PR & Issue auto remove - pr & issue well be labeled as stale for inactive in 7 days - staled PR & Issue well be remove in 7 days - run this workflow every day on 1:30 a.m. * Update stale.yml * feat(bot): Create .owners.yml for Auto Assign (#176) * Create .owners.yml: for issue/pr assign automatically * Update .owners.yml * Update .owners.yml fix typo * [feat]: add pal reasoning script (#163) * [Feat] Add PAL inference script * Update README.md * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update tools/pal_inference.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update pal script * Update README.md * restore .ore-commit-config.yaml * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update pal inference script * Update READMD.md * Update internlm/utils/interface.py Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> * Update pal script * Update pal script * Update script * Add docstring * Update format * Update script * Update script * Update script --------- Co-authored-by: BigDong <yudongwang1226@gmail.com> Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> * test(ci_scripts): add timeout settings and clean work after the slurm job (#185) * restore pr test on develop branch * add mask * add post action to cancel slurm job * remove readonly attribute on job log * add debug info * debug job log * try stdin * use stdin * set default value avoid error * try setting readonly on job log * performance echo * remove debug info * use squeue to check slurm job status * restore the lossed parm * litmit retry times * use exclusive to avoid port already in use * optimize loop body * remove partition * add {} for variables * set env variable for slurm partition --------- Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * refactor(tools): move interface.py and import it to web_demo (#195) * move interface.py and import it to web_demo * typo * fix(ci): fix lint error * fix(ci): fix lint error --------- Co-authored-by: Sun Peng <sunpengsdu@gmail.com> Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: Kai Chen <chenkaidev@gmail.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com> Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> Co-authored-by: vansin <msnode@163.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com> Co-authored-by: YWMditto <862779238@qq.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com> Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com> Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> Co-authored-by: lvhan028 <lvhan_028@163.com> Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com> Co-authored-by: cx <759046501@qq.com> Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com> Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com> Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com> Co-authored-by: BigDong <yudongwang1226@gmail.com> Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Co-authored-by: huangting4201 <huangting3@sensetime.com> * fix(pipeline_scheduler.py): fix tensor shape err and comm block (#210) * feat(train.py): support torch profiler (#201) * feat(train.py): support torch profiling * feat(train.py): optimize initialize_llm_profile * feat(train.py): profiling with tp0 and dp0 * move sequence parallel context manager to evalation func * fix lint * move the process for type_ids to load_new_batch * fix lint --------- Co-authored-by: yingtongxiong <974106207@qq.com> * feat(ckpt): add auto ckpt load and singal quit (#216) Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * feat(memory_profiler): improve memory profiler (#217) * Feat/overlap_bcast_forward (#218) * feat/support bcast forward overlao * feat/optimize the bcast call * feat/optimize the bcast call * feat/optimize the bcast call * fix lint * fix lint * fix lint * fix lint * add torch.cuda.synchronize in save_checkpoint --------- Co-authored-by: sunpeng <sunpengsdu@gmail.com> * fix(*): move sequence_parallel to parallel config (#224) * move sequence_parallel to parallel config * set the sequece_parallel default value is False * fix lint * fix lint * fix lint * Feat/example training internlm (#212) * feat(train/training_internlm.py): move common init funcs to internlm/train * feat(train/training_internlm.py): update some public funcs * feat(train/training_internlm.py): update some public funcs * feat(evaluation.py): adapt evaluate to streaming dataset * feat(train/training_internlm.py): minor update based on comments * fix(training_internlm.py): set train dataloader persistent_workers true only when num_worker>0 * fix(training_internlm.py): fix demo error * feat(data/utils.py): add new dataset type code for streaming dataset (#225) * test(model): support fp32 with flash_attn (#223) * support tf32 with flash * move autocast to attention * fix lint * fix lint * fix lint * fix lint * fix some bugs in model * modify the convert dtype * fix(pipeline): modify the sequence_parallel in pipeline (#227) * move sequence_parallel to parallel config * set the sequece_parallel default value is False * fix lint * fix lint * fix lint * modify the sequence_parallel in pp * feat(init): add skip args check flag and add zero overlap flag (#222) * feat(init): add skip args check flag * fix(optim): add param overlap enable flag * fix(ci): fix train error (#228) Co-authored-by: huangting4201 <huangting3@sensetime.com> * fix(writer): fix tensorboard resume bug (#229) * fix(train.py): fix overflow grad norm error (#230) * feat(ckpt): add train config into ckpt (#231) --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: Sun Peng <sunpengsdu@gmail.com> Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: Kai Chen <chenkaidev@gmail.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com> Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> Co-authored-by: vansin <msnode@163.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com> Co-authored-by: YWMditto <862779238@qq.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: huangting.p <huangting@sensetime.com> Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com> Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com> Co-authored-by: yingtongxiong <974106207@qq.com> Co-authored-by: cx <759046501@qq.com> Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> Co-authored-by: huangting4201 <huangting3@sensetime.com> Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Co-authored-by: lvhan028 <lvhan_028@163.com> Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com> Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com> Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com> Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com> Co-authored-by: BigDong <yudongwang1226@gmail.com> Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
parent
e1cefaef6b
commit
54f85a6e9a
|
@ -1,5 +1,5 @@
|
|||
name: demo-in-readme
|
||||
on:
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
|
@ -110,7 +110,6 @@ jobs:
|
|||
srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --gpus-per-task=2 python ../ci_scripts/model/loaded_as_transformer.py
|
||||
cd ..
|
||||
rm -rf $GITHUB_WORKSPACE/hf_ckpt
|
||||
|
||||
load-chat-model-in-hf:
|
||||
if: ${{ always() }}
|
||||
needs: check-requirements
|
||||
|
|
|
@ -115,6 +115,7 @@ venv.bak/
|
|||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
*.trace.json
|
||||
docs/modelzoo_statistics.md
|
||||
mmdet/.mim
|
||||
work_dirs/
|
||||
|
@ -142,4 +143,5 @@ core.*
|
|||
|
||||
# Run
|
||||
llm_ckpts
|
||||
memory_trace
|
||||
events.*
|
||||
memory_trace
|
||||
|
|
|
@ -49,5 +49,5 @@ repos:
|
|||
args:
|
||||
[
|
||||
'--rcfile=.pylintrc',
|
||||
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703'
|
||||
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
|
||||
]
|
|
@ -15,6 +15,7 @@ VOCAB_SIZE = 103168
|
|||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
ckpt = dict(
|
||||
enable_save_ckpt=True,
|
||||
# Path to save training ckpt.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER,
|
||||
# Path to continue training ckpt (load model weights and scheduler/context states).
|
||||
|
|
|
@ -5,7 +5,7 @@ set -x
|
|||
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
|
||||
readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
|
||||
readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
|
||||
expected_num=21
|
||||
expected_num=22
|
||||
exit_code=0
|
||||
|
||||
source ./ci_scripts/common/basic_func.sh
|
||||
|
|
|
@ -5,7 +5,7 @@ set -x
|
|||
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
|
||||
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
|
||||
readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
|
||||
expected_num=21
|
||||
expected_num=22
|
||||
exit_code=0
|
||||
|
||||
source ./ci_scripts/common/basic_func.sh
|
||||
|
|
|
@ -5,7 +5,7 @@ set -x
|
|||
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
|
||||
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
|
||||
readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
|
||||
expected_num=21
|
||||
expected_num=22
|
||||
exit_code=0
|
||||
|
||||
source ./ci_scripts/common/basic_func.sh
|
||||
|
|
|
@ -7,31 +7,43 @@ MLP_RATIO = 8 / 3
|
|||
NUM_LAYER = 32
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
# oss: 'boto3:s3://model_weights/XXX'
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
|
||||
# boto3 Ckpt folder format:
|
||||
# import os
|
||||
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
|
||||
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
|
||||
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
|
||||
CHECKPOINT_EVERY = 50
|
||||
ckpt = dict(
|
||||
# Path to save training ckpt.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER,
|
||||
# Path to continue training ckpt (load model weights and scheduler/context states).
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER,
|
||||
# Path to initialize with given model weights.
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER,
|
||||
checkpoint_every=50,
|
||||
# Wheter to load optimizer states when continuing training.
|
||||
load_optimizer=True,
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
|
||||
load_optimizer=True, # Wheter to load optimizer states when continuing training.
|
||||
checkpoint_every=CHECKPOINT_EVERY,
|
||||
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
|
||||
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
|
||||
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
|
||||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batch contained in one gradient update
|
||||
micro_num=4,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
# defaults to the value of micro_num
|
||||
valid_micro_num=4,
|
||||
# defaults to 0, means disable evaluate
|
||||
valid_every=50,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
|
@ -39,6 +51,7 @@ data = dict(
|
|||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
)
|
||||
|
||||
grad_scaler = dict(
|
||||
|
@ -62,7 +75,8 @@ grad_scaler = dict(
|
|||
|
||||
hybrid_zero_optimizer = dict(
|
||||
# Enable low_level_optimzer overlap_communication
|
||||
zero_overlap_communication=True,
|
||||
overlap_sync_grad=True,
|
||||
overlap_sync_param=True,
|
||||
# bucket size for nccl communication params
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
# grad clipping
|
||||
|
@ -107,9 +121,11 @@ model = dict(
|
|||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16",
|
||||
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
|
@ -118,11 +134,15 @@ zero1 parallel:
|
|||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel: pipeline parallel size, only 1 is accepted currently.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
|
||||
pipeline parallel (dict):
|
||||
1. size: int, the size of pipeline parallel.
|
||||
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
sequence_parallel=False,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
|
|
|
@ -174,7 +174,7 @@ parallel = dict(
|
|||
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
|
||||
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
|
||||
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
|
||||
- pipeline: pipeline parallel size, currently only supports 1, default value is 1
|
||||
- pipeline: pipeline parallel size, default value is 1
|
||||
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
|
||||
|
||||
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
|
||||
|
|
|
@ -159,7 +159,7 @@ parallel = dict(
|
|||
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
|
||||
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
|
||||
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
|
||||
- pipeline:流水线并行大小,目前只支持 1,默认值为 1
|
||||
- pipeline:流水线并行大小,默认值为 1
|
||||
- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1
|
||||
|
||||
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
from .p2p import (
|
||||
AsynCommunicator,
|
||||
recv_backward,
|
||||
recv_forward,
|
||||
send_backward,
|
||||
send_backward_and_recv_next_backward_async,
|
||||
send_backward_recv_backward,
|
||||
send_backward_recv_forward,
|
||||
send_forward,
|
||||
send_forward_and_recv_next_forward_async,
|
||||
send_forward_backward_recv_forward_backward,
|
||||
send_forward_recv_backward,
|
||||
send_forward_recv_forward,
|
||||
)
|
||||
from .utils import recv_obj_meta, send_obj_meta
|
||||
|
||||
__all__ = [
|
||||
"send_forward",
|
||||
"send_forward_recv_forward",
|
||||
"send_forward_backward_recv_forward_backward",
|
||||
"send_backward",
|
||||
"send_backward_recv_backward",
|
||||
"send_backward_recv_forward",
|
||||
"send_forward_recv_backward",
|
||||
"recv_backward",
|
||||
"recv_forward",
|
||||
"send_obj_meta",
|
||||
"recv_obj_meta",
|
||||
"send_backward_and_recv_next_backward_async",
|
||||
"send_forward_and_recv_next_forward_async",
|
||||
"AsynCommunicator",
|
||||
]
|
|
@ -0,0 +1,582 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
|
||||
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
|
||||
|
||||
Args:
|
||||
tensor_shape (:class:`torch.Size`): shape of tensor
|
||||
chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
|
||||
|
||||
Returns:
|
||||
Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
|
||||
"""
|
||||
if chunk_tensor:
|
||||
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
|
||||
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
if tensor_chunk_shape % tensor_parallel_world_size == 0:
|
||||
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
chunk_tensor = False
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
return tensor_chunk_shape, chunk_tensor
|
||||
|
||||
|
||||
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
|
||||
if isinstance(recv_shapes, torch.Size):
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
|
||||
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
return buffer_recv, recv_split
|
||||
buffer_recv = []
|
||||
for recv_shape in recv_shapes:
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
|
||||
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv, recv_split
|
||||
|
||||
|
||||
def process_object_to_send(object_send, scatter_gather_tensors):
|
||||
if isinstance(object_send, torch.Tensor):
|
||||
send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
|
||||
if send_split:
|
||||
object_send = split_tensor_into_1d_equal_chunks(object_send)
|
||||
return object_send
|
||||
|
||||
object_send_list = []
|
||||
for tensor_send in object_send:
|
||||
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
|
||||
if send_split:
|
||||
object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
|
||||
else:
|
||||
object_send_list.append(tensor_send)
|
||||
object_send = tuple(object_send_list)
|
||||
|
||||
return object_send
|
||||
|
||||
|
||||
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
|
||||
ops_queue.append(op_to_add)
|
||||
else:
|
||||
for tensor_to_comm in obj:
|
||||
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
|
||||
ops_queue.append(op_to_add)
|
||||
|
||||
|
||||
def _communicate(
|
||||
object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
|
||||
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
|
||||
recv_prev: bool = False,
|
||||
recv_next: bool = False,
|
||||
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
prev_rank: int = None,
|
||||
next_rank: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
"""
|
||||
Adapted from megatron.p2p_communication.
|
||||
Communicate tensors between stages. Used as helper method in other
|
||||
communication methods that are used in pipeline schedule.
|
||||
Takes the following arguments:
|
||||
object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
|
||||
(no tensor sent if set to None).
|
||||
object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
|
||||
(no tensor sent if set to None).
|
||||
recv_prev (bool): boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next (bool): boolean for whether tensor should be received from
|
||||
next rank.
|
||||
recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
|
||||
from the previous stage, defualts to None.
|
||||
recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
|
||||
from the next stage, defualts to None.
|
||||
prev_rank (int): the rank of the previous pipeline stage, defualts to None,
|
||||
next_rank (int): the rank of the next pipeline stage, defualts to None,
|
||||
dtype (torch.dtype): data type of intermediate buffers, defaults to None
|
||||
scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
|
||||
|
||||
Returns:
|
||||
Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
|
||||
"""
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
|
||||
if recv_prev:
|
||||
assert recv_prev_shape is not None
|
||||
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
|
||||
recv_prev_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
if recv_next:
|
||||
assert recv_next_shape is not None
|
||||
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
|
||||
recv_next_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
if object_send_prev is not None or recv_prev:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if object_send_next is not None or recv_next:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if object_send_prev is not None:
|
||||
object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
|
||||
|
||||
if object_send_next is not None:
|
||||
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
|
||||
|
||||
ops = []
|
||||
if object_send_prev is not None:
|
||||
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
|
||||
|
||||
if tensor_recv_prev is not None:
|
||||
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
|
||||
|
||||
if tensor_recv_next is not None:
|
||||
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
|
||||
|
||||
if object_send_next is not None:
|
||||
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if recv_prev and recv_prev_split:
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_prev)):
|
||||
tensor_recv_prev[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
if recv_next and recv_next_split:
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_next)):
|
||||
tensor_recv_next[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
return tensor_recv_prev, tensor_recv_next
|
||||
|
||||
|
||||
def recv_forward(
|
||||
input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
prev_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
|
||||
"""
|
||||
input_tensor, _ = _communicate(
|
||||
recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(
|
||||
output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
||||
Args:
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(
|
||||
recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
|
||||
_communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_forward_recv_backward(
|
||||
output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the gradient tensor from the
|
||||
next stage in pipeline as the input gradient tensor of this stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
prev_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the gradient tensor to the
|
||||
previous stage in pipeline, while receives the output tensor from the
|
||||
previous stage in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
|
||||
"""
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(
|
||||
output_tensor,
|
||||
input_tensor_shape,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the output tensor from the
|
||||
previous stage in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
|
||||
"""
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
output_grad_shape,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the gradient tensor to the
|
||||
previous stage in pipeline, while receives the gradient tensor from the
|
||||
next member in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(
|
||||
output_tensor,
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
output_grad_shape,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
|
||||
the gradient tensor to the previous stage, while receives the input gradient tensor from the
|
||||
next stage and the input tensor from the previous stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
|
||||
from the previous.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
|
||||
from the next.
|
||||
|
||||
Returns:
|
||||
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
|
||||
List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
|
||||
"""
|
||||
input_tensor, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor, output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_and_recv_next_forward_async(
|
||||
output_tensor,
|
||||
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
):
|
||||
"""send forward output to next rank and recv forward input from prev rank"""
|
||||
|
||||
reqs = []
|
||||
tensor_recv_prev = None
|
||||
|
||||
# prepare send opreations
|
||||
if output_tensor is not None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
|
||||
|
||||
if isinstance(output_tensor, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
|
||||
else:
|
||||
for tensor_to_comm in output_tensor:
|
||||
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
|
||||
|
||||
# prepare receive opreations
|
||||
if recv_prev_shape is not None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
# create receive buffer
|
||||
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
|
||||
recv_prev_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
# generate async receive opterations
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
|
||||
else:
|
||||
for tensor_to_comm in tensor_recv_prev:
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
|
||||
|
||||
if len(reqs) > 0:
|
||||
reqs = dist.batch_isend_irecv(reqs)
|
||||
|
||||
# return and do other things
|
||||
yield
|
||||
|
||||
# check communication completed
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Process received data
|
||||
if recv_prev_shape is not None and recv_prev_split:
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_prev)):
|
||||
tensor_recv_prev[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
yield tensor_recv_prev
|
||||
|
||||
|
||||
def send_backward_and_recv_next_backward_async(
|
||||
input_tensor,
|
||||
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
):
|
||||
reqs = []
|
||||
tensor_recv_next = None
|
||||
|
||||
# prepare send opreations
|
||||
if input_tensor is not None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
|
||||
|
||||
if isinstance(input_tensor, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
|
||||
else:
|
||||
for tensor_to_comm in input_tensor:
|
||||
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
|
||||
|
||||
# prepare receive opreations
|
||||
if recv_next_shape is not None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
# create receive buffer
|
||||
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
|
||||
recv_next_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
# generate async receive opreations
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
|
||||
else:
|
||||
for tensor_to_comm in tensor_recv_next:
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
|
||||
|
||||
if len(reqs) > 0:
|
||||
reqs = dist.batch_isend_irecv(reqs)
|
||||
|
||||
# return and do other things
|
||||
yield
|
||||
|
||||
# check communication completed
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Process received data
|
||||
if recv_next_shape is not None and recv_next_split:
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_next)):
|
||||
tensor_recv_next[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
yield tensor_recv_next
|
||||
|
||||
|
||||
class AsynCommunicator:
|
||||
"""AsynCommunicator for managing async communication."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
|
||||
recv_shape: Union[torch.Size, List[torch.Size]],
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
forward: bool = True,
|
||||
) -> None:
|
||||
self._need_receive = recv_shape is not None
|
||||
|
||||
if forward:
|
||||
self._coroutine = send_forward_and_recv_next_forward_async(
|
||||
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
else:
|
||||
self._coroutine = send_backward_and_recv_next_backward_async(
|
||||
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
@property
|
||||
def need_receive(self) -> bool:
|
||||
return self._need_receive
|
||||
|
||||
def start(self) -> None:
|
||||
next(self._coroutine)
|
||||
|
||||
def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
received = next(self._coroutine)
|
||||
self._coroutine.close()
|
||||
|
||||
return received
|
|
@ -0,0 +1,125 @@
|
|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
def send_meta_helper(obj, next_rank, tensor_kwargs):
|
||||
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
|
||||
dist.send(send_ndims, next_rank)
|
||||
dist.send(send_shape, next_rank)
|
||||
|
||||
|
||||
def send_obj_meta(obj, next_rank=None):
|
||||
"""Sends obj meta information before sending a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be sent before communications. This function
|
||||
synchronizes with :func:`recv_obj_meta`.
|
||||
|
||||
Args:
|
||||
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
|
||||
need_meta (bool, optional): If False, meta information won't be sent.
|
||||
next_rank (int): The rank of the next member in pipeline parallel group.
|
||||
|
||||
Returns:
|
||||
bool: False
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
send_meta_helper(obj, next_rank, tensor_kwargs)
|
||||
else:
|
||||
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
for tensor_to_send in obj:
|
||||
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
|
||||
|
||||
|
||||
def recv_meta_helper(prev_rank, tensor_kwargs):
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_ndims, prev_rank)
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.recv(recv_shape, prev_rank)
|
||||
return recv_shape
|
||||
|
||||
|
||||
def recv_obj_meta(prev_rank=None) -> torch.Size:
|
||||
"""Receives obj meta information before receiving a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be received before communications. This function
|
||||
synchronizes with :func:`send_obj_meta`.
|
||||
|
||||
Args:
|
||||
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
|
||||
prev_rank (int): The rank of the source of the obj.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape = torch.Size(recv_shape)
|
||||
else:
|
||||
obj_shape = []
|
||||
for _ in range(recv_obj_nums.item()):
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape.append(torch.Size(recv_shape))
|
||||
|
||||
return obj_shape
|
||||
|
||||
|
||||
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
|
||||
"""Break a tensor into equal 1D chunks.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): Tensor to be split before communication.
|
||||
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The split tensor
|
||||
"""
|
||||
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
|
||||
start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
end_index = start_index + partition_size
|
||||
if new_buffer:
|
||||
data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
|
||||
data.copy_(tensor.view(-1)[start_index:end_index])
|
||||
else:
|
||||
data = tensor.view(-1)[start_index:end_index]
|
||||
return data
|
||||
|
||||
|
||||
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Opposite of above function, gather values from model parallel ranks.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The gathered tensor.
|
||||
"""
|
||||
world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
numel = torch.numel(tensor)
|
||||
numel_gathered = world_size * numel
|
||||
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
|
||||
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
|
||||
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
|
||||
return gathered
|
|
@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||
if self.pipeline_parallel_size > 1:
|
||||
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
||||
|
||||
for initializer in initializers:
|
||||
parallel_setting = initializer.init_dist_group()
|
||||
if isinstance(parallel_setting, list):
|
||||
|
|
|
@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
|
|||
input_ = input_.float()
|
||||
return input_
|
||||
|
||||
def convert_to_fp32(self, out):
|
||||
"""Converts the output to fp32"""
|
||||
if isinstance(out, Tensor):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
def _reduce_module_buffer(self):
|
||||
"""
|
||||
All-reduces the buffers (e.g., running stats of batch normalization) across
|
||||
|
@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
|
|||
out = self.model(*args, **kwargs)
|
||||
|
||||
if self._output_to_fp32:
|
||||
if isinstance(out, Tensor):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
out = self.convert_to_fp32(out)
|
||||
return out
|
||||
|
|
|
@ -1,279 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
It mainly composes of forward_backward_step for gradient backward and
|
||||
optimizer_step for parameters update.
|
||||
For the convenience to enable FP16, we aggregate all codes that contain the
|
||||
control of FP16 in class schedule.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
|
||||
them into data and label.
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None):
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
@abstractmethod
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""To perform actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
forward_only (bool): If True, the process won't include backward.
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine: Engine, inputs: Any):
|
||||
"""Calls the engine with the given inputs.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return engine(inputs)
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
return engine(*inputs)
|
||||
elif isinstance(inputs, dict):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
|
||||
"""Calls the engine's criterion with the given outputs and labels.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
|
||||
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = (labels,)
|
||||
|
||||
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
|
||||
return engine.criterion(*outputs, *labels)
|
||||
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
|
||||
return engine.criterion(*outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, dict):
|
||||
return engine.criterion(**outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
"""A helper schedule class for no pipeline parallelism running environment.
|
||||
During one process, it loads a batch of dataset and feeds it to the model.
|
||||
After getting the output and calculating the loss, it will use :meth:`step`
|
||||
to update the parameters if it is in training mode.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
|
||||
and returns a tuple in the form of (data, label), and it will be executed in load_batch.
|
||||
gradient_accumulation_steps(int, optional): the steps of gradient accumulation, 1 for disable
|
||||
gradient accumulation.
|
||||
|
||||
Example:
|
||||
# this shows an example of customized data_process_func
|
||||
def data_process_func(dataloader_output):
|
||||
item1, item2, item3 = dataloader_output
|
||||
data = (item1, item2)
|
||||
label = item3
|
||||
return data, label
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
|
||||
# check that non-pipeline schedule data process func only takes in one parameter
|
||||
# which is the batch data
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 1, (
|
||||
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
|
||||
"which is a tuple of tensors for the current batch, "
|
||||
"i.e. data_process_func(dataloader_output)."
|
||||
)
|
||||
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_batch_size = 1 # static batch size for flash attetion.
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""Performs actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _load_accum_batch(self, data: Any, label: Any):
|
||||
"""Loads a batch of data and label for gradient accumulation.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be loaded.
|
||||
label (Any): The label to be loaded.
|
||||
"""
|
||||
_data = {
|
||||
k: v[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
for k, v in data.items()
|
||||
}
|
||||
_label = label[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
|
||||
return _data, _label
|
||||
|
||||
def _train_one_batch(
|
||||
self,
|
||||
data: Any,
|
||||
label: Any,
|
||||
engine: Engine,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be trained.
|
||||
label (Any): The label for the data.
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
|
||||
be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
scale_loss (int, optional): The scale factor for the loss.
|
||||
"""
|
||||
|
||||
# forward
|
||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||
output = self._call_engine(engine, data)
|
||||
|
||||
if return_loss:
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
loss /= scale_loss
|
||||
|
||||
# backward
|
||||
if not forward_only:
|
||||
engine.backward(loss)
|
||||
|
||||
if not return_loss:
|
||||
loss = None
|
||||
|
||||
return output, loss
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
|
||||
forward_only (bool, optional):
|
||||
If True, the model is run for the forward pass, else back propagation will be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
return_output_label (bool, optional): Output and label will be returned if True.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
batch_data, batch_size = engine.load_batch(data_iter)
|
||||
|
||||
assert (
|
||||
batch_size == self._grad_accum_size
|
||||
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
|
||||
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
else:
|
||||
# if not batch data process func is given,
|
||||
# then we regard the batch data as a simple tuple of (data, label)
|
||||
data, label = batch_data
|
||||
|
||||
loss = 0 if return_loss else None
|
||||
outputs = []
|
||||
labels = []
|
||||
|
||||
# reset accumulation microbatch offset
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
for _current_accum_step in range(self._grad_accum_size):
|
||||
if _current_accum_step == self._grad_accum_size - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
else:
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
loss += _loss
|
||||
if return_output_label:
|
||||
outputs.append(_output)
|
||||
labels.append(_label)
|
||||
|
||||
if not return_output_label:
|
||||
outputs, labels = None, None
|
||||
|
||||
return outputs, labels, loss
|
|
@ -0,0 +1,12 @@
|
|||
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
|
||||
from .no_pipeline_scheduler import NonPipelineScheduler
|
||||
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
|
||||
|
||||
__all__ = [
|
||||
"BaseScheduler",
|
||||
"NonPipelineScheduler",
|
||||
"InterleavedPipelineScheduler",
|
||||
"PipelineScheduler",
|
||||
"SchedulerHook",
|
||||
"SchedulerMetricHook",
|
||||
]
|
|
@ -0,0 +1,187 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
It mainly composes of forward_backward_step for gradient backward and
|
||||
optimizer_step for parameters update.
|
||||
For the convenience to enable FP16, we aggregate all codes that contain the
|
||||
control of FP16 in class schedule.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
|
||||
them into data and label.
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None):
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
@abstractmethod
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""To perform actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _load_micro_batch(self, data, label, offset, micro_bsz):
|
||||
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
|
||||
micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
|
||||
micro_batch_label = label[offset : offset + micro_bsz]
|
||||
|
||||
return micro_batch_data, micro_batch_label
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
forward_only (bool): If True, the process won't include backward.
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine: Engine, inputs: Any):
|
||||
"""Calls the engine with the given inputs.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return engine(inputs)
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
return engine(*inputs)
|
||||
elif isinstance(inputs, dict):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
|
||||
"""Calls the engine's criterion with the given outputs and labels.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
|
||||
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = (labels,)
|
||||
|
||||
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
|
||||
return engine.criterion(*outputs, *labels)
|
||||
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
|
||||
return engine.criterion(*outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, dict):
|
||||
return engine.criterion(**outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
||||
|
||||
class SchedulerHook(ABC):
|
||||
"""
|
||||
Scheduler Hook.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def before_forward(self, scheduler, inputs) -> None:
|
||||
"""Actions before forward"""
|
||||
|
||||
@abstractmethod
|
||||
def after_forward(self, scheduler, outputs) -> None:
|
||||
"""Actions after forward"""
|
||||
|
||||
@abstractmethod
|
||||
def before_criterion(self, scheduler, outputs, label) -> None:
|
||||
"""Actions before criterion"""
|
||||
|
||||
@abstractmethod
|
||||
def after_criterion(self, scheduler, loss) -> None:
|
||||
"""Actions after criterion"""
|
||||
|
||||
@abstractmethod
|
||||
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
|
||||
"""Actions before backward"""
|
||||
|
||||
@abstractmethod
|
||||
def after_backward(self, scheduler, inputs_grad) -> None:
|
||||
"""Actions after backward"""
|
||||
|
||||
@abstractmethod
|
||||
def post_helper_func(self, scheduler, outputs, label) -> None:
|
||||
"""A post helper function"""
|
||||
|
||||
|
||||
class SchedulerMetricHook(SchedulerHook):
|
||||
"""
|
||||
Scheduler Metric Hook.
|
||||
"""
|
||||
|
||||
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
|
||||
self._post_func = metric
|
||||
self._skip = skip
|
||||
|
||||
def before_forward(self, scheduler, inputs) -> None:
|
||||
if not self._skip:
|
||||
timer("fwd").start()
|
||||
|
||||
def after_forward(self, scheduler, outputs) -> None:
|
||||
if not self._skip:
|
||||
timer("fwd").stop()
|
||||
|
||||
def before_criterion(self, scheduler, outputs, label) -> None:
|
||||
if not self._skip:
|
||||
timer("cal_loss").start()
|
||||
|
||||
def after_criterion(self, scheduler, loss) -> None:
|
||||
if not self._skip:
|
||||
timer("cal_loss").stop()
|
||||
|
||||
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
|
||||
if not self._skip:
|
||||
timer("bwd").start()
|
||||
|
||||
def after_backward(self, scheduler, inputs_grad) -> None:
|
||||
if not self._skip:
|
||||
timer("bwd").stop()
|
||||
|
||||
def post_helper_func(self, scheduler, outputs, label) -> None:
|
||||
if self._post_func is not None:
|
||||
self._post_func(outputs, label)
|
|
@ -0,0 +1,192 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from typing import Any, Callable, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
"""A helper schedule class for no pipeline parallelism running environment.
|
||||
During one process, it loads a batch of dataset and feeds it to the model.
|
||||
After getting the output and calculating the loss, it will use :meth:`step`
|
||||
to update the parameters if it is in training mode.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
|
||||
and returns a tuple in the form of (data, label), and it will be executed in load_batch.
|
||||
gradient_accumulation_steps(int, optional): the steps of gradient accumulation, 1 for disable
|
||||
gradient accumulation.
|
||||
|
||||
Example:
|
||||
# this shows an example of customized data_process_func
|
||||
def data_process_func(dataloader_output):
|
||||
item1, item2, item3 = dataloader_output
|
||||
data = (item1, item2)
|
||||
label = item3
|
||||
return data, label
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_process_func: Callable = None,
|
||||
gradient_accumulation_size: int = 1,
|
||||
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
||||
):
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
self._hooks = scheduler_hooks
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""Performs actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
|
||||
for hook in self._hooks:
|
||||
getattr(hook, func_name)(self, *args, **kwargs)
|
||||
|
||||
def _load_accum_batch(self, data: Any, label: Any):
|
||||
"""Loads a batch of data and label for gradient accumulation.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be loaded.
|
||||
label (Any): The label to be loaded.
|
||||
"""
|
||||
|
||||
_data, _label = self._load_micro_batch(
|
||||
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
|
||||
)
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
|
||||
if self.data_process_func:
|
||||
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
|
||||
_label = self.data_process_func(_label, _data["cu_seqlens"])
|
||||
_data.pop("cu_seqlens")
|
||||
_data.pop("indexes")
|
||||
|
||||
return _data, _label
|
||||
|
||||
def _train_one_batch(
|
||||
self,
|
||||
data: Any,
|
||||
label: Any,
|
||||
engine: Engine,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be trained.
|
||||
label (Any): The label for the data.
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
|
||||
be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
scale_loss (int, optional): The scale factor for the loss.
|
||||
"""
|
||||
|
||||
# forward
|
||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||
self._call_hooks("before_forward", data)
|
||||
output = self._call_engine(engine, data)
|
||||
self._call_hooks("after_forward", output)
|
||||
|
||||
self._call_hooks("post_helper_func", output, label)
|
||||
|
||||
if return_loss:
|
||||
self._call_hooks("before_criterion", output, label)
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
self._call_hooks("after_criterion", loss)
|
||||
loss /= scale_loss
|
||||
|
||||
# backward
|
||||
if not forward_only:
|
||||
self._call_hooks("before_backward", None, None)
|
||||
engine.backward(loss)
|
||||
self._call_hooks("after_backward", None)
|
||||
|
||||
if not return_loss:
|
||||
loss = None
|
||||
|
||||
return output, loss
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
|
||||
forward_only (bool, optional):
|
||||
If True, the model is run for the forward pass, else back propagation will be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
return_output_label (bool, optional): Output and label will be returned if True.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
batch_data, batch_size = engine.load_batch(data_iter)
|
||||
|
||||
assert (
|
||||
batch_size % self._grad_accum_size == 0
|
||||
), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
|
||||
self._grad_accum_batch_size = batch_size // self._grad_accum_size
|
||||
|
||||
data, label = batch_data
|
||||
|
||||
loss = 0 if return_loss else None
|
||||
outputs = []
|
||||
labels = []
|
||||
|
||||
# reset accumulation microbatch offset
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
for _current_accum_step in range(self._grad_accum_size):
|
||||
if _current_accum_step == self._grad_accum_size - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
else:
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
loss += _loss
|
||||
if return_output_label:
|
||||
outputs.append(_output)
|
||||
labels.append(_label)
|
||||
|
||||
if not return_output_label:
|
||||
outputs, labels = None, None
|
||||
|
||||
return outputs, labels, loss
|
File diff suppressed because it is too large
Load Diff
|
@ -7,7 +7,12 @@ import json
|
|||
from typing import Iterable, Optional
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
|
||||
from internlm.core.scheduler import (
|
||||
BaseScheduler,
|
||||
InterleavedPipelineScheduler,
|
||||
NonPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
)
|
||||
|
||||
|
||||
class TrainState:
|
||||
|
@ -33,6 +38,11 @@ class TrainState:
|
|||
# Total step count
|
||||
self.total_steps: int = config.data.total_steps
|
||||
|
||||
# resume tensorboard folder, need load from checkpoint or set manually.
|
||||
self.resume_tb_folder = config.resume_tb_folder
|
||||
|
||||
self.tensorboard_folder = config.tensorboard_folder
|
||||
|
||||
def init_batch_sampler(self, train_dl):
|
||||
# Copy of the batch sampler from the DataLoader
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
|
@ -71,6 +81,9 @@ class TrainState:
|
|||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
|
||||
# resume tensorboard from older tensorboard_folder
|
||||
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"batch_count": self.batch_count,
|
||||
|
@ -78,6 +91,7 @@ class TrainState:
|
|||
"num_consumed_tokens": self.num_consumed_tokens,
|
||||
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
||||
"step_count": self.step_count,
|
||||
"tensorboard_folder": self.tensorboard_folder,
|
||||
}
|
||||
|
||||
|
||||
|
@ -112,8 +126,7 @@ class Trainer:
|
|||
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
|
||||
self._schedule = schedule
|
||||
|
||||
if self.uses_pipeline:
|
||||
self._schedule.pre_processing(self)
|
||||
self._schedule.pre_processing(self._engine)
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
|
@ -126,7 +139,7 @@ class Trainer:
|
|||
@property
|
||||
def uses_pipeline(self):
|
||||
"""Returns whether the pipeline parallel is used or not."""
|
||||
return False
|
||||
return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
|
||||
|
||||
def train(self):
|
||||
self._engine.train()
|
||||
|
|
|
@ -219,11 +219,6 @@ class StaticBatchSampler:
|
|||
assert (
|
||||
batch_size - self.start_bsz
|
||||
) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
|
||||
assert (
|
||||
self.start_bsz // micro_bsz >= 4
|
||||
), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
|
||||
`{micro_bsz}`, so that the pipeline can run correctly"
|
||||
|
||||
assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
|
||||
assert (
|
||||
self.start_bsz % micro_bsz == 0
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from internlm.data.single_dataset import JsonlDataset
|
||||
|
||||
|
||||
def get_dataset_dict(folder, split="valid") -> Dict:
|
||||
"""
|
||||
Return a dictionary of Datasets from a folder containing data files for validation.
|
||||
|
||||
Args:
|
||||
folder (str): The path to the folder containing data files.
|
||||
split (str): The split of the data files to be used, default is "valid".
|
||||
|
||||
Returns:
|
||||
A dictionary containing Datasets for each folder in the given path
|
||||
that contains data files with the specified split.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the given folder does not exist.
|
||||
|
||||
Example:
|
||||
If the given folder is as follows,
|
||||
- data
|
||||
- zhihu
|
||||
- xxx.bin
|
||||
- valid.bin
|
||||
- baike
|
||||
- xxx.bin
|
||||
- valid.bin
|
||||
|
||||
The returned dictionary will be,
|
||||
{
|
||||
'zhihu': Dataset,
|
||||
'baike': Dataset
|
||||
}
|
||||
"""
|
||||
|
||||
assert os.path.exists(folder), f"folder `{folder}` not exists"
|
||||
data_dict = {}
|
||||
|
||||
for root, dirs, files in os.walk(folder, followlinks=True):
|
||||
dirs.sort() # The order is guaranteed, and the newly added data starting with z needs to be ranked behind
|
||||
datasets = []
|
||||
for fn in sorted(files): # Need sorted to ensure that the order is consistent
|
||||
if fn.endswith(".bin") and split in fn:
|
||||
fp = os.path.join(root, fn)
|
||||
ds = JsonlDataset(fp)
|
||||
datasets.append(ds)
|
||||
if datasets:
|
||||
ds = ConcatDataset(datasets=datasets)
|
||||
data_dict[os.path.basename(root)] = ds
|
||||
|
||||
return data_dict
|
|
@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
|
|||
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
||||
return out
|
||||
|
||||
def cal_pos_unpack(self, index):
|
||||
if index == 0:
|
||||
pre_pos = 0
|
||||
else:
|
||||
pre_pos = index * gpc.config.data["micro_bsz"]
|
||||
|
||||
pos = (index + 1) * gpc.config.data["micro_bsz"]
|
||||
return pre_pos, pos
|
||||
|
||||
def build_unpack(self, index):
|
||||
|
||||
pre_pos, pos = self.cal_pos_unpack(index)
|
||||
|
||||
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
|
||||
|
||||
while pre_pos < pos and pre_pos < len(self.dataset):
|
||||
sample_idx = self.sample_indices[pre_pos]
|
||||
sample = self.dataset[sample_idx]
|
||||
length = min(len(sample["tokens"]), self.max_length_per_sample)
|
||||
chunk = sample["tokens"][0:length]
|
||||
pack.extend(chunk)
|
||||
_labels = deepcopy(chunk)
|
||||
_labels = list(_labels[1:]) + [-100]
|
||||
assert len(_labels) == len(chunk), (_labels, chunk)
|
||||
labels.extend(_labels)
|
||||
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
||||
cu_seqlens.append(cu_seqlens[-1] + len(chunk))
|
||||
indexes.extend(list(range(length)))
|
||||
pre_pos = pre_pos + 1
|
||||
|
||||
if cu_seqlens[-1] != self.packed_length:
|
||||
pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
|
||||
cu_seqlens.append(self.packed_length)
|
||||
|
||||
assert len(pack) == self.packed_length
|
||||
|
||||
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
||||
return out
|
||||
|
||||
def __getitem__(self, item: int) -> Dict:
|
||||
"""Given the index, it returns a dict as
|
||||
{
|
||||
|
@ -154,8 +196,11 @@ class PackedDataset(torch.utils.data.Dataset):
|
|||
}
|
||||
"""
|
||||
|
||||
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
||||
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
||||
if gpc.config.model.use_flash_attn:
|
||||
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
||||
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
||||
|
||||
return self.build_unpack(item)
|
||||
|
||||
|
||||
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2}
|
||||
|
||||
|
||||
def get_dataset_type_id(path):
|
||||
|
@ -13,3 +17,30 @@ def get_dataset_type_id(path):
|
|||
match_idxes.append(idx)
|
||||
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
|
||||
return match_idxes[0]
|
||||
|
||||
|
||||
def unpack_data(input_ids, cu_seqlens):
|
||||
"""
|
||||
input_ids: (n, packed_length)
|
||||
Return:
|
||||
output: (batch_size, max_length)
|
||||
"""
|
||||
|
||||
bsz = input_ids.shape[0]
|
||||
|
||||
num_sequence = gpc.config.data["micro_bsz"]
|
||||
|
||||
outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
||||
cu_seqlens_slice = cu_seqlens[i]
|
||||
for j in range(num_sequence):
|
||||
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
|
||||
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
|
||||
outputs[i] = output
|
||||
|
||||
if bsz == 1:
|
||||
outputs = outputs.squeeze(0)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
from .initialize_trainer import initialize_trainer
|
||||
from .launch import get_default_parser, launch_from_slurm, launch_from_torch
|
||||
from .launch import (
|
||||
get_default_parser,
|
||||
initialize_distributed_env,
|
||||
launch_from_slurm,
|
||||
launch_from_torch,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_default_parser",
|
||||
"initialize_trainer",
|
||||
"launch_from_slurm",
|
||||
"launch_from_torch",
|
||||
"initialize_distributed_env",
|
||||
]
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
|
||||
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
from torch import nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
@ -11,11 +11,19 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
|
||||
from internlm.core.scheduler import (
|
||||
InterleavedPipelineScheduler,
|
||||
NonPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
SchedulerHook,
|
||||
)
|
||||
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
|
||||
from internlm.core.trainer import Trainer
|
||||
from internlm.data.utils import unpack_data
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
||||
from internlm.utils.common import get_current_device
|
||||
|
@ -29,6 +37,7 @@ def initialize_trainer(
|
|||
test_dataloader: Optional[Iterable] = None,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
||||
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
||||
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
|
||||
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||||
loaded into gpc.config.
|
||||
|
@ -59,6 +68,8 @@ def initialize_trainer(
|
|||
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
|
||||
|
||||
# gradient handler, only support PipelineSharedModuleGradientHandler now
|
||||
if gpc.is_using_pp():
|
||||
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
|
||||
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
|
||||
gradient_handlers = []
|
||||
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
|
||||
|
@ -67,8 +78,50 @@ def initialize_trainer(
|
|||
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
|
||||
gradient_handlers.append(handler)
|
||||
|
||||
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
# initialize scheduler for trainer
|
||||
scheduler = None
|
||||
if gpc.config.model.use_flash_attn:
|
||||
data_fn = None
|
||||
else:
|
||||
data_fn = unpack_data
|
||||
if gpc.is_using_pp():
|
||||
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
|
||||
tensor_shape = get_tensor_shape()
|
||||
use_interleaved = (
|
||||
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
|
||||
)
|
||||
scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
|
||||
if use_interleaved:
|
||||
if isinstance(model, nn.Sequential):
|
||||
model = nn.ModuleList([model])
|
||||
|
||||
communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=communication_overlap,
|
||||
)
|
||||
else:
|
||||
scheduler = PipelineScheduler(
|
||||
data_process_func=data_fn,
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
else:
|
||||
scheduler = NonPipelineScheduler(
|
||||
data_process_func=data_fn,
|
||||
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# initialize engine for trainer
|
||||
engine = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
|
|
|
@ -10,7 +10,9 @@ import torch
|
|||
|
||||
from internlm.core.context import Config
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_master_node
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.storage_manager import init_storage_manager
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -38,7 +40,7 @@ def get_default_parser():
|
|||
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
||||
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
||||
parser.add_argument("--seed", type=int, default=1024)
|
||||
parser.add_argument("--profiling", default=False, action="store_true", help="enable/diable profiling.")
|
||||
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -89,6 +91,12 @@ def args_sanity_check():
|
|||
if "valid_folder" not in data:
|
||||
data._add_item("valid_folder", None)
|
||||
|
||||
if "valid_micro_num" not in data:
|
||||
data._add_item("valid_micro_num", data.micro_num)
|
||||
|
||||
if "valid_every" not in data:
|
||||
data._add_item("valid_every", 0)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
|
||||
logger.info(f"seq_len: {data.seq_len}")
|
||||
|
@ -97,36 +105,104 @@ def args_sanity_check():
|
|||
logger.info(f"packed_length: {data.packed_length}")
|
||||
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
|
||||
logger.info(f"min_length: {data.min_length}")
|
||||
logger.info(f"valid_micro_num: {data.valid_micro_num}")
|
||||
logger.info(f"valid_every: {data.valid_every}")
|
||||
|
||||
# processing the checkpoint config
|
||||
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
|
||||
gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
|
||||
ckpt = gpc.config.ckpt
|
||||
if "enable_save_ckpt" not in ckpt:
|
||||
ckpt._add_item("enable_save_ckpt", False)
|
||||
|
||||
if "load_optimizer" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_optimizer", True)
|
||||
# Saving checkpoint args.
|
||||
if ckpt.enable_save_ckpt:
|
||||
assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
|
||||
assert ckpt.checkpoint_every > 0
|
||||
assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
|
||||
|
||||
if "save_ckpt_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("save_ckpt_folder", None)
|
||||
if "async_upload" not in ckpt:
|
||||
ckpt._add_item("async_upload", False) # async defalut is False.
|
||||
else:
|
||||
if ckpt.async_upload:
|
||||
assert "save_ckpt_folder" in ckpt
|
||||
if "boto3:" not in ckpt.save_ckpt_folder:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
|
||||
)
|
||||
ckpt.async_upload = False
|
||||
else:
|
||||
if "async_upload_tmp_folder" not in ckpt:
|
||||
ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
|
||||
|
||||
if "load_ckpt_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_ckpt_folder", None)
|
||||
if not ckpt.async_upload:
|
||||
ckpt._add_item("async_upload_tmp_folder", None)
|
||||
|
||||
if "load_model_only_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_model_only_folder", None)
|
||||
if "snapshot_ckpt_folder" not in ckpt:
|
||||
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
|
||||
|
||||
assert not (
|
||||
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
|
||||
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
|
||||
if "oss_snapshot_freq" not in ckpt:
|
||||
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
|
||||
else:
|
||||
ckpt._add_item("checkpoint_every", float("inf"))
|
||||
ckpt._add_item("oss_snapshot_freq", float("inf"))
|
||||
ckpt._add_item("save_ckpt_folder", None)
|
||||
ckpt._add_item("async_upload", False)
|
||||
ckpt._add_item("async_upload_tmp_folder", None)
|
||||
ckpt._add_item("snapshot_ckpt_folder", None)
|
||||
ckpt._add_item("snapshot_ckpt_folder", None)
|
||||
|
||||
gpc.config.ckpt._add_item(
|
||||
"enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
|
||||
)
|
||||
# Loading checkpoint args.
|
||||
if "load_model_only_folder" not in ckpt:
|
||||
ckpt._add_item("load_model_only_folder", None)
|
||||
|
||||
if "load_ckpt_folder" not in ckpt:
|
||||
ckpt._add_item("load_ckpt_folder", None)
|
||||
|
||||
if "load_optimizer" not in ckpt:
|
||||
ckpt._add_item("load_optimizer", True)
|
||||
|
||||
if "stop_file_path" not in ckpt:
|
||||
ckpt._add_item("stop_file_path", None)
|
||||
|
||||
if "load_given_ckpt" not in ckpt:
|
||||
# If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
|
||||
# to auto-load latest checkpoint.
|
||||
ckpt._add_item("load_given_ckpt", False)
|
||||
|
||||
if ckpt.load_given_ckpt:
|
||||
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
|
||||
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
|
||||
logger.warning(
|
||||
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
|
||||
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
||||
)
|
||||
ckpt.load_model_only_folder = None
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
|
||||
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
|
||||
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
|
||||
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
|
||||
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
|
||||
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
|
||||
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
|
||||
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
|
||||
|
||||
# initialization storage manager
|
||||
init_storage_manager(ckpt)
|
||||
|
||||
# tensorboard writer config
|
||||
if "enable_tb" not in gpc.config:
|
||||
gpc.config._add_item("enable_tb", True)
|
||||
if "tensorboard_folder" not in gpc.config:
|
||||
gpc.config._add_item(
|
||||
"tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
|
||||
)
|
||||
if "resume_tb_folder" not in gpc.config:
|
||||
gpc.config._add_item(
|
||||
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
|
||||
)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
|
||||
logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
|
||||
|
||||
# cudnn
|
||||
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
||||
|
@ -144,12 +220,24 @@ def args_sanity_check():
|
|||
logger.warning("dtype is not set, use torch.float16 by defalut!")
|
||||
model._add_item("dtype", torch.float16)
|
||||
else:
|
||||
if model.dtype == "torch.bfloat16":
|
||||
model.dtype = torch.bfloat16
|
||||
elif model.dtype in ("torch.float16", "torch.half"):
|
||||
model.dtype = torch.float16
|
||||
if gpc.config.model.dtype == "torch.bfloat16":
|
||||
gpc.config.model.dtype = torch.bfloat16
|
||||
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
|
||||
gpc.config.model.dtype = torch.float16
|
||||
elif gpc.config.model.dtype == "torch.float32":
|
||||
gpc.config.model.dtype = torch.float32
|
||||
elif gpc.config.model.dtype == "torch.tf32":
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
gpc.config.model.dtype = torch.float32
|
||||
else:
|
||||
assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
|
||||
assert gpc.config.model.dtype in [
|
||||
"torch.float16",
|
||||
"torch.half",
|
||||
"torch.bfloat16",
|
||||
"torch.float32",
|
||||
"torch.tf32",
|
||||
]
|
||||
|
||||
if "checkpoint" in model:
|
||||
if model.checkpoint is True:
|
||||
|
@ -177,6 +265,35 @@ def args_sanity_check():
|
|||
logger.info("+" * 15 + " beta2_scheduler Info " + "+" * 15) # pylint: disable=W1201
|
||||
logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}")
|
||||
|
||||
# process the model config
|
||||
if "use_flash_attn" not in gpc.config.model:
|
||||
gpc.config.model._add_item("use_flash_attn", True)
|
||||
|
||||
# process the parallel config
|
||||
if "sequence_parallel" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("sequence_parallel", False)
|
||||
else:
|
||||
assert not (
|
||||
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
|
||||
), "sequence parallel does not support use_flash_attn=False"
|
||||
|
||||
# feishu webhook address for alerting
|
||||
if "alert_address" not in gpc.config:
|
||||
gpc.config._add_item("alert_address", None)
|
||||
|
||||
optim_ckpt = gpc.config.hybrid_zero_optimizer
|
||||
if "zero_overlap_communication" in optim_ckpt:
|
||||
# Compatible with the old interfaces.
|
||||
optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
|
||||
if "overlap_sync_grad" not in optim_ckpt:
|
||||
optim_ckpt._add_item("overlap_sync_grad", False)
|
||||
if "overlap_sync_param" not in optim_ckpt:
|
||||
optim_ckpt._add_item("overlap_sync_param", False)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
|
||||
)
|
||||
|
||||
|
||||
def launch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
|
@ -223,8 +340,6 @@ def launch(
|
|||
# init process groups for different parallel modes from config
|
||||
gpc.init_parallel_groups()
|
||||
|
||||
args_sanity_check()
|
||||
|
||||
# set cuda device
|
||||
if torch.cuda.is_available():
|
||||
# if local rank is not given, calculate automatically
|
||||
|
@ -277,7 +392,11 @@ def launch_from_slurm(
|
|||
)
|
||||
|
||||
|
||||
def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
|
||||
def launch_from_torch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
backend: str = "nccl",
|
||||
seed: int = 1024,
|
||||
):
|
||||
"""A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
|
||||
from the environment variables set by PyTorch
|
||||
|
||||
|
@ -305,3 +424,38 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nc
|
|||
backend=backend,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
def initialize_distributed_env(
|
||||
config: str,
|
||||
launcher: str = "slurm",
|
||||
master_port: int = 8888,
|
||||
seed: int = 1024,
|
||||
args_check=True,
|
||||
):
|
||||
"""
|
||||
Initialize distributed environment for distributed training.
|
||||
|
||||
Args:
|
||||
config (str): Config file path.
|
||||
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
|
||||
master_port (str): The master port for distributed training. 8888 by default.
|
||||
seed (int, optional): Specified random seed for every process. 1024 by default.
|
||||
"""
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if launcher == "torch":
|
||||
launch_from_torch(config=config, seed=seed)
|
||||
elif launcher == "slurm":
|
||||
launch_from_slurm(
|
||||
config=config,
|
||||
host=get_master_node(),
|
||||
port=master_port,
|
||||
seed=seed,
|
||||
)
|
||||
else:
|
||||
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
|
||||
|
||||
if args_check:
|
||||
args_sanity_check()
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
from .embedding import Embedding1D, RotaryEmbedding
|
||||
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
||||
from .metrics import AccPerplex
|
||||
from .modeling_internlm import build_model_with_cfg
|
||||
from .multi_head_attention import MHA
|
||||
from .utils import gather_forward_split_backward
|
||||
|
@ -13,6 +14,7 @@ __all__ = [
|
|||
"RotaryEmbedding",
|
||||
"RewardModelLinear",
|
||||
"ScaleColumnParallelLinear",
|
||||
"AccPerplex",
|
||||
"MHA",
|
||||
"gather_forward_split_backward",
|
||||
"build_model_with_cfg",
|
||||
|
|
|
@ -7,13 +7,14 @@ import rotary_emb
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
|
||||
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
|
||||
from torch import Tensor, nn
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
from .utils import gather_forward_split_backward
|
||||
from .utils import gather_forward_split_backward, split_forward_gather_backward
|
||||
|
||||
|
||||
class Embedding1D(nn.Module):
|
||||
|
@ -56,6 +57,9 @@ class Embedding1D(nn.Module):
|
|||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
||||
|
||||
if gpc.config.parallel.sequence_parallel:
|
||||
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
@ -108,6 +112,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|||
|
||||
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
||||
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
||||
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
|
@ -176,7 +181,15 @@ class RotaryEmbedding(torch.nn.Module):
|
|||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
||||
|
||||
def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, qkv: torch.Tensor, **kwargs):
|
||||
if kwargs.get("indexes", None) is not None:
|
||||
return self._forward(qkv, kwargs.pop("indexes"))
|
||||
if kwargs.get("inference_params", None) is not None:
|
||||
return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
|
||||
else:
|
||||
return self._eval_forward(qkv)
|
||||
|
||||
def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._update_cos_sin_cache(qkv, indexes)
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
|
||||
|
@ -189,7 +202,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|||
self._sin_k_cached[indexes],
|
||||
)
|
||||
|
||||
def eval_forward(self, qkv, seqlen_offset=0):
|
||||
def _eval_forward(self, qkv, seqlen_offset=0):
|
||||
"""
|
||||
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
||||
token in the batch.
|
||||
|
|
|
@ -5,15 +5,13 @@ from typing import Optional
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.ops.fused_dense import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
fused_dense_func,
|
||||
)
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
||||
from flash_attn.utils.distributed import all_reduce, reduce_scatter
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.utils import fused_dense_func_torch
|
||||
|
||||
|
||||
class ScaleColumnParallelLinear(nn.Linear):
|
||||
|
@ -40,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
|
@ -50,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
|
||||
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
def forward(self, input): # pylint: disable=W0622
|
||||
|
@ -61,8 +57,12 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
|
||||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
return fused_dense_func_torch(
|
||||
input,
|
||||
weight,
|
||||
self.bias,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
|
@ -89,12 +89,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
|||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
|
||||
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
|
||||
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
if bias:
|
||||
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
|
@ -107,11 +106,37 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
|||
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
|
||||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
return fused_dense_func_torch(
|
||||
input,
|
||||
weight,
|
||||
self.bias,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
class ColumnParallelLinearTorch(ColumnParallelLinear):
|
||||
def forward(self, x):
|
||||
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
# we do an all_gather of x before doing the matmul.
|
||||
# If not, then the input is already gathered.
|
||||
|
||||
return fused_dense_func_torch(
|
||||
x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
)
|
||||
|
||||
|
||||
class RowParallelLinearTorch(RowParallelLinear):
|
||||
def forward(self, x):
|
||||
"""
|
||||
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
|
||||
a reduce_scatter of the result.
|
||||
"""
|
||||
out = fused_dense_func_torch(x, self.weight, self.bias)
|
||||
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
||||
return reduce_fn(out, self.process_group)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
FeedForward.
|
||||
|
@ -143,24 +168,30 @@ class FeedForward(nn.Module):
|
|||
|
||||
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(
|
||||
self.w1 = ColumnParallelLinearTorch(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = ColumnParallelLinear(
|
||||
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
|
||||
self.w2 = ColumnParallelLinearTorch(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w3 = RowParallelLinear(
|
||||
self.w3 = RowParallelLinearTorch(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
|
||||
from torch_scatter import scatter
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.parallel import is_no_pp_or_last_stage
|
||||
|
||||
|
||||
class AccPerplex:
|
||||
"""
|
||||
AccPerplex module for calculating model's accuracy and perplexity metrics.
|
||||
|
||||
Args:
|
||||
device: The GPU device.
|
||||
tp_pg: The tensor parallel process group.
|
||||
dp_pg: The data parallel process group.
|
||||
tokenizer: For calculating BPB.
|
||||
dataset_types (List[str]): Various data types that will be used in the current training process,
|
||||
such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
|
||||
in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
|
||||
"""
|
||||
|
||||
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
|
||||
self.device = device
|
||||
self.right = torch.Tensor([0]).to(device=device)
|
||||
self.total = torch.Tensor([0]).to(device=device)
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device)
|
||||
self.tp_pg = tp_pg
|
||||
self.dp_pg = dp_pg
|
||||
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
|
||||
self.tokenizer = tokenizer
|
||||
self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
|
||||
self.batch_shift = 0
|
||||
self.type_ids = None
|
||||
if dataset_types is not None:
|
||||
self.dataset_types = dataset_types
|
||||
self.total_type_count = len(dataset_types)
|
||||
self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
|
||||
self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
|
||||
|
||||
self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)
|
||||
|
||||
def set_current_type_ids(self, type_ids: torch.Tensor):
|
||||
self.batch_shift = 0
|
||||
self.type_ids = type_ids.cuda()
|
||||
|
||||
def __call__(self, logits, labels):
|
||||
return self.update(logits, labels, type_ids=self.type_ids)
|
||||
|
||||
def update(self, logits, labels, type_ids=None):
|
||||
if gpc.config.model.use_flash_attn:
|
||||
micro_bsz = labels.size(0)
|
||||
else:
|
||||
micro_bsz = 1
|
||||
if type_ids is not None:
|
||||
type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
|
||||
self.batch_shift += 1
|
||||
self.loss_with_type_id.update(logits, labels, type_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
|
||||
logits = logits.detach().clone()
|
||||
labels = labels.detach().clone()
|
||||
|
||||
if self.tokenizer: # need to calculate bits per bytes
|
||||
sequences = self.tokenizer.decode_ids(labels.tolist())
|
||||
self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))
|
||||
|
||||
shift_logits = logits.view(-1, logits.size(-1))
|
||||
shift_labels = labels.view(-1)
|
||||
# There is a shift according to the current rank, because the logits are split
|
||||
pred_shift = self.tp_local_rank * logits.shape[-1]
|
||||
|
||||
logits_max = torch.max(shift_logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
|
||||
# Determine whether the maximum value of the current local tensor is the global maximum value
|
||||
logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]
|
||||
|
||||
corrects = torch.logical_and(
|
||||
(shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
|
||||
).long()
|
||||
mask = shift_labels.ne(-100).long()
|
||||
if hasattr(self, "total_type_count"):
|
||||
ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
|
||||
token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
|
||||
if len(ds_acc) < self.total_type_count:
|
||||
ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
|
||||
token_num_type = torch.cat(
|
||||
[token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
|
||||
)
|
||||
self.ds_tokens += token_num_type
|
||||
sync_tensor = ds_acc
|
||||
torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
self.ds_right += sync_tensor.view(-1)
|
||||
|
||||
acc = corrects.sum()
|
||||
torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
self.right += acc # Masked_fill is not needed here because -100 is not available anyway
|
||||
self.total += mask.sum()
|
||||
|
||||
# Subtract the maximum value.
|
||||
shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
# Get the partition's vocab indecies
|
||||
partition_vocab_size = shift_logits.size()[-1]
|
||||
vocab_start_index = partition_vocab_size * self.tp_local_rank
|
||||
vocab_end_index = vocab_start_index + partition_vocab_size
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
|
||||
masked_target = shift_labels - vocab_start_index
|
||||
masked_target[target_mask] = 0
|
||||
|
||||
# Get predicted-logits = logits[target].
|
||||
# For Simplicity, we convert logits to a 2-D tensor with size
|
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
|
||||
logits_2d = shift_logits.view(-1, partition_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
|
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
|
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
|
||||
predicted_logits = predicted_logits_1d.view_as(shift_labels) # bsz x max_len
|
||||
predicted_logits[target_mask] = 0.0
|
||||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
pred_exp_logits = torch.exp(predicted_logits)
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
|
||||
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
|
||||
self.total_log_probs += total_log_probs
|
||||
|
||||
def get_metric(self, reset=True):
|
||||
if is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
if hasattr(self, "total_type_count"):
|
||||
torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
if self.tokenizer:
|
||||
torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
|
||||
acc = round((self.right / self.total).item(), 4)
|
||||
perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
|
||||
bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0
|
||||
|
||||
if hasattr(self, "total_type_count"):
|
||||
ds_acc = {}
|
||||
ds_tokens = {}
|
||||
for i in range(self.total_type_count):
|
||||
ds_acc[f"acc/{self.dataset_types[i]}"] = round(
|
||||
(self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
|
||||
)
|
||||
ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
|
||||
if reset:
|
||||
self.right.fill_(0)
|
||||
self.total.fill_(0)
|
||||
self.total_log_probs.fill_(0)
|
||||
self.total_bytes.fill_(0)
|
||||
if hasattr(self, "total_type_count"):
|
||||
self.ds_right.fill_(0)
|
||||
self.ds_tokens.fill_(0)
|
||||
if self.tokenizer is not None:
|
||||
res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
|
||||
else:
|
||||
res = {"acc": acc, "perplexity": perplexity}
|
||||
if hasattr(self, "total_type_count"):
|
||||
res.update(ds_acc)
|
||||
res.update(ds_tokens)
|
||||
|
||||
loss_res = self.loss_with_type_id.get_metric()
|
||||
res.update(loss_res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class LossWithTypeId:
|
||||
"""
|
||||
Notice the loss value computed here may be not the same with the main info loss,
|
||||
cause loss here is the reduced result of the data parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
|
||||
self.device = device
|
||||
self.dp_pg = dp_pg
|
||||
|
||||
self.loss = torch.Tensor([0.0]).to(device=device)
|
||||
self.token_num = torch.Tensor([0.0]).to(device=device)
|
||||
|
||||
if dataset_types is not None:
|
||||
self.dataset_types = dataset_types
|
||||
self.total_type_count = len(dataset_types)
|
||||
self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
|
||||
self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
|
||||
|
||||
self.loss_fn = FlashCrossEntropyLoss(
|
||||
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
|
||||
)
|
||||
|
||||
def update(self, logits, labels, type_ids=None):
|
||||
with torch.no_grad():
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
logits = logits.contiguous().view(-1, logits.size(-1))
|
||||
labels = labels.contiguous().view(-1)
|
||||
loss_list = self.loss_fn(logits, labels)
|
||||
|
||||
cond = labels != -100
|
||||
real_loss_list = loss_list[cond]
|
||||
self.loss += real_loss_list.sum()
|
||||
self.token_num += real_loss_list.numel()
|
||||
|
||||
if hasattr(self, "total_type_count"):
|
||||
type_ids = type_ids.contiguous().view(-1).to(self.device)
|
||||
real_type_ids = type_ids[cond]
|
||||
|
||||
loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
|
||||
token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")
|
||||
|
||||
if len(loss_list_type) < self.total_type_count:
|
||||
loss_list_type = torch.cat(
|
||||
[loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
|
||||
)
|
||||
token_num_type = torch.cat(
|
||||
[token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
|
||||
)
|
||||
self.ds_loss += loss_list_type
|
||||
self.ds_token_num += token_num_type
|
||||
|
||||
def get_metric(self, reset=True):
|
||||
if is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
if hasattr(self, "total_type_count"):
|
||||
torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
|
||||
loss = round((self.loss / self.token_num).item(), 4)
|
||||
res = {
|
||||
"loss_from_metric": loss,
|
||||
}
|
||||
if hasattr(self, "total_type_count"):
|
||||
ds_loss = {}
|
||||
for i in range(self.total_type_count):
|
||||
ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
|
||||
res.update(ds_loss)
|
||||
|
||||
if reset:
|
||||
self.loss.fill_(0.0)
|
||||
self.token_num.fill_(0.0)
|
||||
if hasattr(self, "total_type_count"):
|
||||
self.ds_loss.fill_(0.0)
|
||||
self.ds_token_num.fill_(0.0)
|
||||
|
||||
return res
|
|
@ -5,7 +5,6 @@ import math
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
from flash_attn.modules.embedding import ParallelGPT2Embeddings
|
||||
from flash_attn.modules.mlp import ParallelFusedMLP
|
||||
from torch import nn
|
||||
|
@ -20,7 +19,7 @@ from internlm.model.linear import (
|
|||
ScaleColumnParallelLinear,
|
||||
)
|
||||
from internlm.model.multi_head_attention import MHA
|
||||
from internlm.model.utils import gather_forward_split_backward
|
||||
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.utils.checkpoint import activation_checkpoint
|
||||
from internlm.utils.common import filter_kwargs
|
||||
|
@ -30,6 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
|
|||
MODEL_TYPE = "INTERNLM"
|
||||
|
||||
logger = get_logger(__file__)
|
||||
RMSNorm = try_import_RMSNorm()
|
||||
|
||||
|
||||
class PackedFlashBaseLayer1D(nn.Module):
|
||||
|
@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether use flash-attn. True by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -68,12 +69,14 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
|
||||
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
|
||||
self.layer_idx = layer_idx
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
self.mixer = MHA(
|
||||
|
@ -86,8 +89,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
layer_idx=layer_idx,
|
||||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=True,
|
||||
sequence_parallel=False,
|
||||
use_flash_attn=use_flash_attn,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -119,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
|
@ -243,6 +245,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -271,6 +274,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -290,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
max_position_embeddings=-1,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
padding_idx=None,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -317,6 +321,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
|
@ -331,7 +336,6 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
weight_scale=embed_grad_scale,
|
||||
|
@ -397,9 +401,10 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
|
|||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
# all_parts = partition_uniform_with_embed2(num_layers, pipeline_size, num_chunks)
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"The layer sharding is {all_parts}.")
|
||||
|
||||
models = []
|
||||
|
||||
|
@ -445,6 +450,8 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint=True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
sequence_parallel: bool = False, # pylint: disable=W0613
|
||||
):
|
||||
"""
|
||||
Builde model with config
|
||||
|
@ -474,6 +481,7 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
|
||||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -496,6 +504,7 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
|
||||
|
|
|
@ -12,12 +12,12 @@ from flash_attn.modules.mha import (
|
|||
SelfAttention,
|
||||
_update_kv_cache,
|
||||
)
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.embedding import RotaryEmbedding
|
||||
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
|
@ -43,6 +43,7 @@ class MHA(nn.Module):
|
|||
of x will be done before doing the matmul.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -57,8 +58,7 @@ class MHA(nn.Module):
|
|||
layer_idx: int = None,
|
||||
rotary_emb_dim: int = 0,
|
||||
rotary_emb_scale_base: int = 0,
|
||||
use_flash_attn: bool = False,
|
||||
sequence_parallel: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
|
@ -77,12 +77,12 @@ class MHA(nn.Module):
|
|||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
|
||||
|
||||
# notice here should change bias=True
|
||||
self.Wqkv = ColumnParallelLinear(
|
||||
self.Wqkv = ColumnParallelLinearTorch(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
process_group,
|
||||
bias=True,
|
||||
sequence_parallel=sequence_parallel,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
) # according to https://spaces.ac.cn/archives/9577
|
||||
|
||||
|
@ -94,8 +94,12 @@ class MHA(nn.Module):
|
|||
)
|
||||
|
||||
# output projection always have the bias (for now)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
|
||||
self.out_proj = RowParallelLinearTorch(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
process_group,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
# need to assign tp attribute so that internlm know it is tensor parallel module
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
|
@ -107,9 +111,9 @@ class MHA(nn.Module):
|
|||
if kwargs.get("indexes", None) is not None:
|
||||
return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
|
||||
else:
|
||||
return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
|
||||
return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
|
||||
|
||||
def _forward(self, x, seqlen=None, inference_params=None):
|
||||
def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
|
||||
|
@ -124,13 +128,17 @@ class MHA(nn.Module):
|
|||
qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
if inference_params is None:
|
||||
qkv = self.rotary_emb.eval_forward(qkv)
|
||||
else:
|
||||
qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
|
||||
kwargs["inference_params"] = inference_params
|
||||
qkv = self.rotary_emb(qkv, **kwargs)
|
||||
|
||||
if inference_params is None:
|
||||
context = self.inner_attn(qkv)
|
||||
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
|
||||
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
||||
if qkv.dtype not in [torch.float16, torch.bfloat16]:
|
||||
qkv = qkv.to(torch.bfloat16)
|
||||
context = self.inner_attn(qkv).to(x.dtype)
|
||||
else:
|
||||
context = self.inner_attn(qkv)
|
||||
else:
|
||||
q = qkv[:, :, 0]
|
||||
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
||||
|
@ -158,10 +166,18 @@ class MHA(nn.Module):
|
|||
"""
|
||||
qkv = self.Wqkv(x) # total x hsz'
|
||||
qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
|
||||
qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
|
||||
qkv = self.rotary_emb(qkv, **kwargs)
|
||||
kwargs.pop("indexes")
|
||||
|
||||
if inference_params is None:
|
||||
context = self.inner_attn(qkv, **kwargs)
|
||||
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
|
||||
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
||||
if qkv.dtype not in [torch.float16, torch.bfloat16]:
|
||||
qkv = qkv.to(torch.bfloat16)
|
||||
context = self.inner_attn(qkv, **kwargs).to(x.dtype)
|
||||
else:
|
||||
context = self.inner_attn(qkv, **kwargs)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Not support this right now")
|
||||
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def manual_rms_norm(my_input, normalized_shape, weight, eps):
|
||||
# layer norm should always be calculated in float32
|
||||
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
|
||||
variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
|
||||
my_input = my_input * torch.rsqrt(variance + eps)
|
||||
|
||||
if weight is None:
|
||||
return my_input
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
my_input = my_input.to(weight.dtype)
|
||||
|
||||
return weight * my_input
|
||||
|
||||
|
||||
class RMSNormTorch(torch.nn.Module):
|
||||
"""A custom PyTorch module for RMS normalization."""
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-5):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.weight = Parameter(torch.empty(*normalized_shape))
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, _input: torch.Tensor):
|
||||
return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)
|
||||
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
|
||||
def extra_repr(self):
|
||||
return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)
|
|
@ -1,9 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.ops.fused_dense import FusedDenseFunc
|
||||
from flash_attn.utils.distributed import (
|
||||
all_gather_raw,
|
||||
all_reduce_raw,
|
||||
reduce_scatter_raw,
|
||||
)
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1):
|
||||
|
@ -71,3 +86,124 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
|
||||
assert my_input.dtype == grad_output.dtype
|
||||
grad_weight = torch.matmul(grad_output.t(), my_input)
|
||||
grad_bias = grad_output.sum(dim=0) if has_d_bias else None
|
||||
return grad_weight, grad_bias
|
||||
|
||||
|
||||
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
|
||||
class FusedDenseFuncTorch(FusedDenseFunc):
|
||||
"""A custom PyTorch module extending FusedDenseFunc."""
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output, *args):
|
||||
grad_output = grad_output.contiguous()
|
||||
if ctx.return_residual:
|
||||
(grad_input,) = args
|
||||
grad_input = grad_input.contiguous()
|
||||
process_group = ctx.process_group
|
||||
sequence_parallel = ctx.sequence_parallel
|
||||
if ctx.compute_weight_gradient:
|
||||
x, weight = ctx.saved_tensors
|
||||
if process_group is not None and sequence_parallel:
|
||||
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||
else:
|
||||
total_x = x
|
||||
else:
|
||||
(weight,) = ctx.saved_tensors
|
||||
total_x = None
|
||||
batch_shape = grad_output.shape[:-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
if ctx.needs_input_grad[0]:
|
||||
if not ctx.return_residual:
|
||||
grad_input = F.linear(grad_output, weight.t())
|
||||
else:
|
||||
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
|
||||
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
||||
if process_group is not None:
|
||||
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
||||
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
|
||||
else:
|
||||
grad_input = None
|
||||
if ctx.needs_input_grad[1]:
|
||||
assert ctx.compute_weight_gradient
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
# we remove the cuda independence, which is different from flash_attn.
|
||||
grad_weight, grad_bias = linear_bias_wgrad_torch(
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
|
||||
)
|
||||
else:
|
||||
grad_weight = None
|
||||
grad_bias = grad_output if ctx.needs_input_grad[2] else None
|
||||
if process_group is not None and ctx.needs_input_grad[0]:
|
||||
handle_grad_input.wait()
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def fused_dense_func_torch(
|
||||
x: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Optional[Tensor] = None,
|
||||
return_residual: bool = False,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
sequence_parallel: bool = True,
|
||||
):
|
||||
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
|
||||
x.dtype == torch.float32 and torch.is_autocast_enabled()
|
||||
)
|
||||
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
|
||||
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
||||
else:
|
||||
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split the input and keep only the corresponding chuck to the rank.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(input_):
|
||||
return _split(input_, parallel_mode=None)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _split(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def try_import_RMSNorm():
|
||||
"""
|
||||
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
||||
|
||||
"""
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
|
||||
return RMSNorm
|
||||
except ModuleNotFoundError:
|
||||
logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
|
||||
from internlm.model.norm import RMSNormTorch as RMSNorm
|
||||
|
||||
return RMSNorm
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .monitor import initialize_monitor_manager, send_alert_message
|
||||
from .utils import set_env_var
|
||||
|
||||
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
|
|
@ -0,0 +1,53 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
|
||||
"""
|
||||
Use Feishu robot to send messages with the given webhook.
|
||||
|
||||
Args:
|
||||
webhook (str): The webhook to be used to send message.
|
||||
title (str): The message title.
|
||||
message (str): The message body.
|
||||
|
||||
Returns:
|
||||
The response from the request. Or catch the exception and return None.
|
||||
|
||||
Raises:
|
||||
Exception: An exception rasied by the HTTP post request.
|
||||
|
||||
"""
|
||||
|
||||
headers = {"Content-Type": "application/json;charset=utf-8"}
|
||||
msg_body = {
|
||||
"timestamp": int(time.time()),
|
||||
"msg_type": "post",
|
||||
"content": {
|
||||
"post": {
|
||||
"zh_cn": {
|
||||
"title": title,
|
||||
"content": [
|
||||
[
|
||||
{
|
||||
"tag": "text",
|
||||
"text": message,
|
||||
},
|
||||
],
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
|
||||
res = res.json()
|
||||
print(f"Feishu webhook response: {res}")
|
||||
except Exception as err: # pylint: disable=W0703
|
||||
print(f"HTTP Post error: {err}")
|
||||
res = None
|
||||
|
||||
return res
|
|
@ -0,0 +1,226 @@
|
|||
import os
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from threading import Thread
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.monitor.alert import send_feishu_msg_with_webhook
|
||||
from internlm.utils.common import SingletonMeta
|
||||
|
||||
from .utils import get_job_key, set_env_var
|
||||
|
||||
|
||||
def send_alert_message(address: str = None, title: str = None, message: str = None):
|
||||
"""
|
||||
Send alert messages to the given Feishu webhook address in log rank.
|
||||
|
||||
Args:
|
||||
address (str): The alert address to be used to send message, defaults to None.
|
||||
title (str): The message title, defaults to None.
|
||||
message (str): The message body, defaults to None.
|
||||
"""
|
||||
|
||||
if address is not None and gpc.is_rank_for_log():
|
||||
send_feishu_msg_with_webhook(
|
||||
webhook=address,
|
||||
title=title if title else get_job_key(),
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class MonitorTracker(Thread):
|
||||
"""
|
||||
Track job status and alert to Feishu during job training.
|
||||
|
||||
Args:
|
||||
alert_address (str): The Feishu webhook address for sending alerting messages.
|
||||
check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
|
||||
loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alert_address: str,
|
||||
check_interval: float = 300,
|
||||
loss_spike_limit: float = 1.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.alert_address = alert_address
|
||||
self.check_interval = check_interval
|
||||
self.loss_spike_limit = loss_spike_limit
|
||||
self.last_active_time = -1
|
||||
self.last_loss_value = -1
|
||||
self.stopped = False
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
start the monitor tracker.
|
||||
"""
|
||||
|
||||
while not self.stopped:
|
||||
try:
|
||||
self._check_stuck()
|
||||
self._check_loss_spike()
|
||||
except Exception:
|
||||
continue
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def _check_stuck(self):
|
||||
"""
|
||||
Check training status for potential stuck condition.
|
||||
"""
|
||||
|
||||
new_active_time = -1
|
||||
if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
|
||||
new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
|
||||
if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
|
||||
self._send_alert("Training may be in stuck status, please check it.")
|
||||
self.last_active_time = new_active_time
|
||||
|
||||
def _check_loss_spike(self):
|
||||
"""
|
||||
Check for loss value spikes.
|
||||
"""
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
new_loss_value = -1
|
||||
new_step_id = -1
|
||||
if os.getenv("LOSS") is not None:
|
||||
new_loss_value = os.getenv("LOSS")
|
||||
if os.getenv("STEP_ID") is not None:
|
||||
new_step_id = os.getenv("STEP_ID")
|
||||
|
||||
if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
|
||||
assert int(new_step_id) >= 0
|
||||
self._send_alert(
|
||||
f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
|
||||
f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
|
||||
)
|
||||
|
||||
self.last_loss_value = new_loss_value
|
||||
|
||||
def _send_alert(self, message):
|
||||
"""
|
||||
Send alerting message to the Feishu webhook address.
|
||||
|
||||
Args:
|
||||
message (str): The alerting message to be sent.
|
||||
"""
|
||||
|
||||
send_alert_message(
|
||||
address=self.alert_address,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the monitor tracker.
|
||||
"""
|
||||
|
||||
self.stopped = True
|
||||
|
||||
|
||||
class MonitorManager(metaclass=SingletonMeta):
|
||||
"""
|
||||
Monitor Manager for managing monitor thread and monitoring training status.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_spike_limit: float = 1.5) -> None:
|
||||
self.monitor_thread = None
|
||||
self.loss_spike_limit = loss_spike_limit
|
||||
self.last_step_loss = -1
|
||||
|
||||
def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
|
||||
"""Check loss value, if loss spike occurs, send alert message to Feishu."""
|
||||
set_env_var(key="LOSS", value=cur_step_loss)
|
||||
set_env_var(key="STEP_ID", value=step_count)
|
||||
|
||||
if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=(
|
||||
f"Checking step by step: Loss spike may be happened in step {step_count}, "
|
||||
f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
|
||||
),
|
||||
)
|
||||
self.last_step_loss = cur_step_loss
|
||||
|
||||
def monitor_exception(self, alert_address: str = None, excp_info: str = None):
|
||||
"""Catch and format exception information, send alert message to Feishu."""
|
||||
filtered_trace = excp_info.split("\n")[-10:]
|
||||
format_trace = ""
|
||||
for line in filtered_trace:
|
||||
format_trace += "\n" + line
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
|
||||
)
|
||||
|
||||
def handle_sigterm(self, alert_address: str = None):
|
||||
"""Catch SIGTERM signal, and send alert message to Feishu."""
|
||||
|
||||
def sigterm_handler(sys_signal, frame):
|
||||
print("receive frame: ", frame)
|
||||
print("receive signal: ", sys_signal)
|
||||
send_alert_message(
|
||||
address=alert_address,
|
||||
message=f"Process received signal {signal} and exited.",
|
||||
)
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
def start_monitor(
|
||||
self,
|
||||
job_name: str,
|
||||
alert_address: str,
|
||||
monitor_interval_seconds: int = 300,
|
||||
loss_spike_limit: float = 1.5,
|
||||
):
|
||||
"""
|
||||
Initialize and start monitor thread for checking training job status, loss spike and so on.
|
||||
|
||||
Args:
|
||||
job_name (str): The training job name.
|
||||
alert_address (str): The Feishu webhook address for sending alert messages.
|
||||
monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
|
||||
loss_spike_limit (float): The limit multiple of current loss to previous loss value, which means loss spike
|
||||
may be occurs, defaults to 1.5.
|
||||
"""
|
||||
|
||||
# initialize some variables for monitoring
|
||||
set_env_var(key="JOB_NAME", value=job_name)
|
||||
|
||||
# start a monitor thread, periodically check the training status
|
||||
self.monitor_thread = MonitorTracker(
|
||||
alert_address=alert_address,
|
||||
check_interval=monitor_interval_seconds,
|
||||
loss_spike_limit=loss_spike_limit,
|
||||
)
|
||||
|
||||
def stop_monitor(self):
|
||||
"""Stop the monitor and alert thread."""
|
||||
if self.monitor_thread is not None:
|
||||
self.monitor_thread.stop()
|
||||
|
||||
|
||||
monitor_manager = MonitorManager()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
|
||||
if alert_address is not None:
|
||||
try:
|
||||
monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
|
||||
monitor_manager.handle_sigterm(alert_address=alert_address)
|
||||
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
|
||||
yield
|
||||
finally:
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
|
||||
)
|
||||
monitor_manager.stop_monitor()
|
||||
else:
|
||||
yield
|
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def now_time():
|
||||
return datetime.now().strftime("%b%d_%H-%M-%S")
|
||||
|
||||
|
||||
def set_env_var(key, value):
|
||||
os.environ[str(key)] = str(value)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = "none"
|
||||
if os.getenv("SLURM_JOB_ID") is not None:
|
||||
job_id = os.getenv("SLURM_JOB_ID")
|
||||
elif os.getenv("K8S_WORKSPACE_ID") is not None:
|
||||
job_id = os.getenv("K8S_WORKSPACE_ID")
|
||||
|
||||
return job_id
|
||||
|
||||
|
||||
def get_job_name():
|
||||
job_name = f"unknown-{now_time()}"
|
||||
if os.getenv("JOB_NAME") is not None:
|
||||
job_name = os.getenv("JOB_NAME")
|
||||
|
||||
return job_name
|
||||
|
||||
|
||||
def get_job_key():
|
||||
return f"{get_job_id()}_{get_job_name()}"
|
|
@ -3,15 +3,15 @@
|
|||
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
|
||||
import amp_C
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from internlm.core.context import Config, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer.store import (
|
||||
BucketStore,
|
||||
GradientStore,
|
||||
|
@ -20,6 +20,7 @@ from internlm.solver.optimizer.store import (
|
|||
)
|
||||
from internlm.solver.optimizer.utils import (
|
||||
DynamicGradScaler,
|
||||
ParamBcastSyncHandler,
|
||||
flatten,
|
||||
get_grad_accumulate_object,
|
||||
has_inf_or_nan,
|
||||
|
@ -28,33 +29,16 @@ from internlm.solver.optimizer.utils import (
|
|||
split_half_float_double,
|
||||
sync_param,
|
||||
)
|
||||
from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
|
||||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import is_model_parallel_parameter
|
||||
|
||||
from .utils import compute_norm
|
||||
|
||||
inf = math.inf
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
class BaseOptimizer(Optimizer):
|
||||
"""
|
||||
Base Optimizer.
|
||||
|
@ -105,12 +89,15 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self,
|
||||
optimizer: Optimizer,
|
||||
cpu_offload=False,
|
||||
overlap_broadcast=False,
|
||||
grad_scal_cfg: Config = None,
|
||||
zero_cfg: Config = None,
|
||||
param_bcast_sync_handler: ParamBcastSyncHandler = None,
|
||||
):
|
||||
# DynamicGradScaler related args
|
||||
initial_scale = grad_scal_cfg.fp16.initial_scale
|
||||
if gpc.config.model.dtype is torch.float32:
|
||||
initial_scale = 1
|
||||
else:
|
||||
initial_scale = grad_scal_cfg.fp16.initial_scale
|
||||
min_scale = grad_scal_cfg.fp16.min_scale
|
||||
growth_interval = grad_scal_cfg.fp16.growth_interval
|
||||
growth_factor = grad_scal_cfg.growth_factor
|
||||
|
@ -119,9 +106,10 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
max_scale = grad_scal_cfg.max_scale
|
||||
|
||||
# Zero related args
|
||||
overlap_communication = zero_cfg.zero_overlap_communication
|
||||
reduce_bucket_size = zero_cfg.reduce_bucket_size
|
||||
clip_grad_norm = zero_cfg.clip_grad_norm
|
||||
self._overlap_sync_grad = zero_cfg.overlap_sync_grad
|
||||
self._overlap_sync_param = zero_cfg.overlap_sync_param
|
||||
|
||||
super().__init__(optim=optimizer)
|
||||
|
||||
|
@ -142,7 +130,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._fp32_flat_param_groups_of_current_rank = dict()
|
||||
|
||||
# communication params
|
||||
self._overlap_communication = overlap_communication
|
||||
# self._overlap_communication = overlap_communication
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
|
||||
# gradient scaler
|
||||
|
@ -173,7 +161,12 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
+ f"zo-{self._zero_local_rank}.pt"
|
||||
)
|
||||
self.params_per_rank_id_dict = []
|
||||
self.overlap_broadcast = overlap_broadcast
|
||||
self._param_bcast_sync_handler = param_bcast_sync_handler
|
||||
if self._overlap_sync_param:
|
||||
assert self._param_bcast_sync_handler is not None
|
||||
self._broadcast_comm_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self._broadcast_comm_stream = torch.cuda.current_stream()
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
|
@ -195,6 +188,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if len(params) != 0:
|
||||
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
|
||||
for param in params:
|
||||
setattr(param, "group_id", group_id)
|
||||
self._param_store.set_param_to_rank(param, rank)
|
||||
|
||||
# move to cpu to make room to create the flat tensor
|
||||
|
@ -240,14 +234,16 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
|
||||
self.skip_grad_reduce = False
|
||||
|
||||
# intialize communication stream for
|
||||
# communication-compuation overlapping
|
||||
if self._overlap_communication:
|
||||
# initialize communication stream for
|
||||
# communication-computation overlapping
|
||||
if self._overlap_sync_grad:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self._comm_stream = torch.cuda.current_stream()
|
||||
|
||||
# reduction hook is only used if overlapping communication
|
||||
# if it is stage 1 without overlapping, no hook will be attached
|
||||
if self._overlap_communication:
|
||||
if self._overlap_sync_grad:
|
||||
self._attach_reduction_hook()
|
||||
|
||||
@property
|
||||
|
@ -281,8 +277,10 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
global_id = str(i)
|
||||
for j in range(len(param.size())):
|
||||
global_id = "_".join([global_id, str(param.size()[j])])
|
||||
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
if self._overlap_sync_param:
|
||||
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
||||
else:
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
|
@ -313,7 +311,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
|
||||
|
||||
reduction_func = partial(
|
||||
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
|
||||
self._store_and_try_reduce_grads_by_bucket,
|
||||
param=param,
|
||||
reduce_rank=reduce_rank,
|
||||
)
|
||||
|
||||
# define hook
|
||||
|
@ -334,7 +334,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_stored_in_bucket(reduce_rank)
|
||||
self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
|
||||
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
|
@ -352,7 +352,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._bucket_store.add_grad(param.grad, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
|
||||
def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
|
||||
def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
|
||||
# reduce grads
|
||||
self._reduce_grads_by_rank(
|
||||
reduce_rank=reduce_rank,
|
||||
|
@ -360,30 +360,27 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
|
||||
)
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
|
||||
for param in params_in_bucket:
|
||||
# the is_param_reduced flag should be False showing that
|
||||
# this param is not reduced before calling self._reduce_grads_by_rank
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
|
||||
for param in params_in_bucket:
|
||||
# the is_param_reduced flag should be False showing that
|
||||
# this param is not reduced before calling self._reduce_grads_by_rank
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = (
|
||||
f"Parameter of size ({param.size()}) has been reduced, "
|
||||
+ "duplicate reduction will lead to arithmetic incorrectness"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if is_param_reduced:
|
||||
msg = (
|
||||
f"Parameter of size ({param.size()}) has been reduced, "
|
||||
+ "duplicate reduction will lead to arithmetic incorrectness"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
|
||||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
if self._param_store.belongs_to_current_rank(param):
|
||||
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
|
||||
else:
|
||||
self._param_store.add_previous_reduced_param(param)
|
||||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
|
||||
|
@ -401,17 +398,17 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
if self._overlap_sync_grad:
|
||||
self._comm_stream.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
with torch.cuda.stream(self._comm_stream):
|
||||
flat = bucket.flatten()
|
||||
reduced_flat = reduce_tensor(
|
||||
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
|
||||
tensor=flat,
|
||||
dtype=self.dtype,
|
||||
dst_rank=reduce_rank,
|
||||
parallel_mode=ParallelMode.DATA,
|
||||
)
|
||||
|
||||
# update the reduced tensor
|
||||
|
@ -438,6 +435,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
reduction_states = self._param_store.get_param_reduction_states()
|
||||
for tensor, _ in reduction_states.items():
|
||||
reduction_states[tensor] = False
|
||||
self._param_store.reset_reduced_data_for_compute_norm()
|
||||
|
||||
# accumulate gradient
|
||||
avg_gradients = self._grad_store._averaged_gradients
|
||||
|
@ -486,6 +484,30 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# Gradients may not be fully synchronized here.
|
||||
|
||||
def _compute_norm_with_stage(
|
||||
self,
|
||||
group_id: int = 0,
|
||||
last_bucket: bool = False,
|
||||
last_stage: bool = False,
|
||||
previous_norm=None,
|
||||
):
|
||||
# compute norm for gradients that have been reduced
|
||||
params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
|
||||
if len(params) == 0:
|
||||
grads = [self.padding_grad]
|
||||
params = [self.padding_tensor]
|
||||
|
||||
if self._clip_grad_norm > 0:
|
||||
# this norm is before scaling, it will be very large
|
||||
norm = compute_norm(
|
||||
gradients=grads,
|
||||
parameters=params,
|
||||
last_stage=last_stage,
|
||||
previous_norm=previous_norm,
|
||||
)
|
||||
|
||||
return norm
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
|
@ -497,88 +519,92 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
"""
|
||||
assert closure is None, "closure is not supported by step()"
|
||||
|
||||
timer("sync_grad").start()
|
||||
# if not overlapping communication (no reduction hook is attached)
|
||||
# we need to manually reduce these gradients
|
||||
if not self._overlap_communication:
|
||||
if not self._overlap_sync_grad:
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
for param in self._fp16_param_groups[group_id]:
|
||||
if param.grad is not None:
|
||||
self._store_and_try_reduce_grads_by_bucket(param)
|
||||
|
||||
# we need to reduce the gradients left in the communication bucket
|
||||
self._reduce_grads_stored_in_bucket()
|
||||
self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
|
||||
|
||||
# compute norm for gradients in the before bucket
|
||||
groups_norms = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
if self._overlap_sync_grad:
|
||||
# grads in the last bucket is reduced
|
||||
self._comm_stream.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# compute norm for gradients in the last bucket
|
||||
total_norms = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
total_norms.append(
|
||||
self._compute_norm_with_stage(
|
||||
group_id=group_id,
|
||||
last_bucket=True,
|
||||
last_stage=True,
|
||||
previous_norm=groups_norms[group_id],
|
||||
)
|
||||
)
|
||||
|
||||
timer("sync_grad").start()
|
||||
self._sync_grad()
|
||||
timer("sync_grad").stop()
|
||||
|
||||
return self._step(closure=closure)
|
||||
return self._step(closure=closure, norms=total_norms)
|
||||
|
||||
def _step(self, closure=None):
|
||||
def _step(self, closure=None, norms=None):
|
||||
assert closure is None, "closure is not supported by step()"
|
||||
|
||||
# check for overflow
|
||||
found_inf = self._check_overflow()
|
||||
found_inf = False
|
||||
# if there is INF values in grades, compute_norm func would also returns -1
|
||||
# thus, we try to avoid call _check_overflow here
|
||||
# found_inf = self._check_overflow()
|
||||
# Because you may encounter inf when computing norm
|
||||
timer("cal_norm").start()
|
||||
norm_groups = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
parameters = self._param_store.get_fp16_params_by_rank_group(
|
||||
group_id=group_id, rank=self._zero_local_rank
|
||||
)
|
||||
else:
|
||||
# in order to prevent collection communication from hanging,
|
||||
# we need to involve rank that are not assigned parameters in compute_norm(),
|
||||
# so we give them a fp16 vector of 0 values.
|
||||
gradients = [self.padding_grad]
|
||||
parameters = [self.padding_tensor]
|
||||
|
||||
if self._clip_grad_norm > 0:
|
||||
# this norm is before scaling, it will be very large
|
||||
norm_group = compute_norm(
|
||||
gradients=gradients,
|
||||
parameters=parameters,
|
||||
)
|
||||
if norm_group == -1:
|
||||
timer("cal_norm").stop()
|
||||
found_inf = True
|
||||
break
|
||||
norm_groups.append(norm_group)
|
||||
if -1 in norms:
|
||||
found_inf = True
|
||||
|
||||
loss_scale = float(self.loss_scale.item()) # backup
|
||||
self.grad_scaler.update(found_inf)
|
||||
if gpc.config.model.dtype is not torch.float32:
|
||||
self.grad_scaler.update(found_inf)
|
||||
# update loss scale if overflow occurs
|
||||
if found_inf:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("Overflow occurs, please check it.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
message="Overflow occurs, please check it.",
|
||||
)
|
||||
self._grad_store._averaged_gradients = dict()
|
||||
self.zero_grad()
|
||||
return False, None
|
||||
return False, norms
|
||||
|
||||
# copy the grad of fp16 param to fp32 param
|
||||
single_grad_partition_groups = []
|
||||
global_norm = 0
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if not self.param_group_has_params[group_id]:
|
||||
continue
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
|
||||
# create flat gradient for the flat fp32 params
|
||||
fp16_avg_grads = gradients
|
||||
flat_fp16_avg_grads = flatten(fp16_avg_grads)
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
with torch.no_grad():
|
||||
flat_fp16_avg_grads = flatten(gradients)
|
||||
self._grad_store.reset_average_gradients_by_group(group_id)
|
||||
gradients = None # release cuda memory
|
||||
|
||||
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
|
||||
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
|
||||
flat_fp16_avg_grads = None # release cuda memory
|
||||
|
||||
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
|
||||
assert (
|
||||
|
@ -588,19 +614,19 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
single_grad_partition_groups.append(flat_fp32_avg_grads)
|
||||
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
|
||||
self._grad_store._averaged_gradients[group_id] = []
|
||||
self._grad_store._averaged_gradients[group_id] = []
|
||||
|
||||
# unscale and clip grads
|
||||
# get the global norm
|
||||
global_norm_groups = []
|
||||
if self._clip_grad_norm > 0:
|
||||
global_norm = sum(norm_groups) ** 0.5
|
||||
for norm in norms:
|
||||
global_norm_groups.append(norm**0.5)
|
||||
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if len(single_grad_partition_groups) != 0:
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
|
||||
if gpc.config.model.dtype is not torch.float32:
|
||||
if len(single_grad_partition_groups) != 0:
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
|
||||
|
||||
timer("cal_norm").stop()
|
||||
# update the parameters
|
||||
timer("step").start()
|
||||
|
||||
|
@ -619,35 +645,40 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
# TODO: support broadcast overlap
|
||||
self.broadcast_params(overlap=False)
|
||||
with torch.cuda.stream(self._broadcast_comm_stream):
|
||||
self.broadcast_params()
|
||||
|
||||
timer("step").stop()
|
||||
|
||||
# update gradients may not be needed here, because the sync_params function is used in initialization,
|
||||
# so synchronization is maintained
|
||||
return True, global_norm / loss_scale
|
||||
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
|
||||
|
||||
def broadcast_params(self, overlap=False):
|
||||
def broadcast_params(self):
|
||||
handles = []
|
||||
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._zero_world_size):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if rank not in self.param_group_no_params_ranks[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||
# assert grank == rank, f"{grank} == {rank}"
|
||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
|
||||
handle = dist.broadcast(
|
||||
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
|
||||
)
|
||||
handles.append(handle)
|
||||
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if rank in self.param_group_no_params_ranks[group_id]:
|
||||
continue
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||
# assert grank == rank, f"{grank} == {rank}"
|
||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
|
||||
handle = dist.broadcast(
|
||||
fp16_param,
|
||||
src=g_rank,
|
||||
group=gpc.get_group(ParallelMode.ZERO1),
|
||||
async_op=True,
|
||||
)
|
||||
|
||||
if not overlap:
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
else:
|
||||
return handles
|
||||
if self._overlap_sync_param:
|
||||
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
|
||||
else:
|
||||
handles.append(handle)
|
||||
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
##################
|
||||
# FP16 Utilities #
|
||||
|
@ -665,22 +696,28 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
break
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
|
||||
dist.all_reduce(
|
||||
self._found_overflow,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.GLOBAL),
|
||||
)
|
||||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = loss_scale
|
||||
combined_scale_groups = []
|
||||
|
||||
if self._clip_grad_norm > 0.0:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1.0:
|
||||
combined_scale = clip * loss_scale
|
||||
for group_id, total_norm in enumerate(total_norm_groups):
|
||||
combined_scale_groups.append(loss_scale)
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1.0:
|
||||
combined_scale_groups[group_id] = clip * loss_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
grad.data.mul_(1.0 / combined_scale)
|
||||
for group_id, grad in enumerate(grad_groups_flat):
|
||||
grad.data.mul_(1.0 / combined_scale_groups[group_id])
|
||||
|
||||
def clip_grad_norm(self, model, max_norm):
|
||||
# will conduct in the step()
|
||||
|
@ -733,87 +770,3 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
if "zero_devide_optim_plan" in states:
|
||||
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, norm_type=2):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters, need total_norm**(1/norm) before using.
|
||||
"""
|
||||
|
||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||
# Norm parameters.
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
for g, p in zip(gradients, parameters):
|
||||
# TODO: consider the pipeline shared parameter
|
||||
if (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) == 0
|
||||
): # if shared between different pipe, only count o
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) != 0
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and not is_model_parallel_parameter(p)
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
): # if not used in each chunk, such as layernorm
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Should not arrive here")
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
|
||||
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
|
||||
|
||||
total_norm = tensor_parallel_norm
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
|
|
@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
|
|||
self._is_param_reduced = dict()
|
||||
self._reduced_param = []
|
||||
|
||||
self._former_bucket_reduced_param = {}
|
||||
self._last_bucket_reduced_param = {}
|
||||
self._former_bucket_reduced_grad = {}
|
||||
self._last_bucket_reduced_grad = {}
|
||||
|
||||
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
|
||||
"""
|
||||
Set the mapping between parameter to rank, each parameter should be owned by a rank.
|
||||
|
@ -223,6 +228,39 @@ class ParameterStore(BaseStore):
|
|||
def add_previous_reduced_param(self, tensor):
|
||||
self._reduced_param.append(tensor)
|
||||
|
||||
def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
|
||||
group_id = getattr(param, "group_id")
|
||||
if last_bucket:
|
||||
if group_id not in self._last_bucket_reduced_param:
|
||||
self._last_bucket_reduced_param[group_id] = []
|
||||
self._last_bucket_reduced_grad[group_id] = []
|
||||
|
||||
self._last_bucket_reduced_param[group_id].append(param)
|
||||
self._last_bucket_reduced_grad[group_id].append(param.grad)
|
||||
else:
|
||||
if group_id not in self._former_bucket_reduced_param:
|
||||
self._former_bucket_reduced_param[group_id] = []
|
||||
self._former_bucket_reduced_grad[group_id] = []
|
||||
|
||||
self._former_bucket_reduced_param[group_id].append(param)
|
||||
self._former_bucket_reduced_grad[group_id].append(param.grad)
|
||||
|
||||
def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
|
||||
if not last_bucket:
|
||||
if group_id not in self._former_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
|
||||
else:
|
||||
if group_id not in self._last_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
|
||||
|
||||
def reset_reduced_data_for_compute_norm(self):
|
||||
self._former_bucket_reduced_param = {}
|
||||
self._last_bucket_reduced_param = {}
|
||||
self._former_bucket_reduced_grad = {}
|
||||
self._last_bucket_reduced_grad = {}
|
||||
|
||||
def clear_grads_of_previous_reduced_params(self):
|
||||
if len(self._reduced_param) > 0:
|
||||
for param in self._reduced_param:
|
||||
|
|
|
@ -1,20 +1,37 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.parallel import is_model_parallel_parameter
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
try:
|
||||
import amp_C
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
APEX_AVAILABLE = True
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("The torch implementation for cal_l2norm is slower than apex. Please note this!")
|
||||
APEX_AVAILABLE = False
|
||||
|
||||
inf = math.inf
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
return _flatten_dense_tensors(input_)
|
||||
|
@ -46,12 +63,19 @@ def get_grad_accumulate_object(tensor):
|
|||
|
||||
|
||||
def split_half_float_double(tensor_list):
|
||||
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
|
||||
buckets = []
|
||||
for _, dtype in enumerate(dtypes):
|
||||
bucket = [t for t in tensor_list if t.type() == dtype]
|
||||
if bucket:
|
||||
buckets.append(bucket)
|
||||
dtype_buckets = {
|
||||
"torch.cuda.HalfTensor": [],
|
||||
"torch.cuda.FloatTensor": [],
|
||||
"torch.cuda.DoubleTensor": [],
|
||||
"torch.cuda.BFloat16Tensor": [],
|
||||
}
|
||||
|
||||
for t in tensor_list:
|
||||
dtype = t.type()
|
||||
if dtype in dtype_buckets:
|
||||
dtype_buckets[dtype].append(t)
|
||||
|
||||
buckets = [bucket for bucket in dtype_buckets.values() if bucket]
|
||||
return buckets
|
||||
|
||||
|
||||
|
@ -150,6 +174,149 @@ def sync_param(flat_tensor, tensor_list):
|
|||
p.data = q.data
|
||||
|
||||
|
||||
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
|
||||
# Convert tensor_list elements to torch.float32
|
||||
tensor_list = [tensor.float() for tensor in tensor_list]
|
||||
norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
|
||||
l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
|
||||
|
||||
if per_tensor:
|
||||
per_tensor_norm = norms_tensor
|
||||
else:
|
||||
per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
|
||||
|
||||
return l2_norm, per_tensor_norm
|
||||
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
if APEX_AVAILABLE:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False, # no per-parameter norm
|
||||
)
|
||||
else:
|
||||
norm, _ = multi_tensor_l2norm_torch(grads, False)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters, need total_norm**(1/norm) before using.
|
||||
"""
|
||||
|
||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||
# Norm parameters.
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
|
||||
|
||||
if last_stage is False:
|
||||
return total_norm_cuda
|
||||
|
||||
if previous_norm is not None:
|
||||
total_norm_cuda = max(total_norm_cuda, previous_norm)
|
||||
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(
|
||||
total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
for g, p in zip(gradients, parameters):
|
||||
# TODO: consider the pipeline shared parameter
|
||||
if (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) == 0
|
||||
): # if shared between different pipe, only count o
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) != 0
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and not is_model_parallel_parameter(p)
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
): # if not used in each chunk, such as layernorm
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Should not arrive here")
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
|
||||
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
|
||||
|
||||
total_norm = tensor_parallel_norm
|
||||
|
||||
if last_stage is False:
|
||||
return total_norm
|
||||
|
||||
if previous_norm is not None:
|
||||
total_norm = total_norm + previous_norm
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(
|
||||
total_norm,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
)
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
"""A base class for the gradient scaler.
|
||||
|
||||
|
@ -313,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
|
|||
self._scale = self._scale.fill_(state_dict["_scale"])
|
||||
self._growth_step = state_dict["_growth_step"]
|
||||
self._hysteresis_step = state_dict["_hysteresis_step"]
|
||||
|
||||
|
||||
class ParamBcastSyncHandler:
|
||||
"""
|
||||
Model Partition Handler for overlap broadcast with forward
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
|
||||
self._block_to_param = OrderedDict() # <key: nn.Module> <value: list(param)>
|
||||
self._param_to_rank = dict() # <key: param> <value: rank)>
|
||||
self._block_to_rank = dict() # <key: nn.Module> <value: rank)>
|
||||
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles))>
|
||||
|
||||
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
total_param_num = sum(p.numel() for p in model.parameters())
|
||||
avg_param_num = total_param_num * 1.0 // zero1_size
|
||||
|
||||
# just want to share same for loop for ModuleList and Module
|
||||
if not isinstance(model, nn.ModuleList):
|
||||
model = [model]
|
||||
|
||||
# record the parameters to transformer/embeding/head/norm block
|
||||
for _chunk in model:
|
||||
if isinstance(_chunk, NaiveAMPModel):
|
||||
_chunk = _chunk.model
|
||||
|
||||
for _, children in _chunk.named_children():
|
||||
# should be the transformer block definaton in modeling_xxx.py
|
||||
if isinstance(children, nn.ModuleList):
|
||||
# record the block that a parameter belongs to
|
||||
for _, block in enumerate(children):
|
||||
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
|
||||
self._block_to_param[block] = list(block.parameters())
|
||||
else:
|
||||
# record the block that a parameter belongs to
|
||||
# self._block_to_param[name] = list(children.parameters())
|
||||
self._block_to_param[children] = list(children.parameters())
|
||||
|
||||
alloc_num = 0
|
||||
rank_to_go = 0
|
||||
|
||||
# process the parameters in block_to_param sequencially,
|
||||
# allocate each parameter to a local rank of ParallelMode.ZERO1,
|
||||
# NOTE that we do NOT consider following scenarios:
|
||||
# 1) whether a parameter is trainable;
|
||||
# 2) paramters maybe in different optimizer group
|
||||
for block, params in self._block_to_param.items():
|
||||
# allocate a model block to a local rank of ParallelMode.ZERO1
|
||||
self._block_to_rank[block] = [rank_to_go]
|
||||
for p in params:
|
||||
alloc_num = alloc_num + p.numel()
|
||||
# in this case, allocate the param to next rank if possible
|
||||
if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
|
||||
rank_to_go = rank_to_go + 1
|
||||
alloc_num = 0
|
||||
self._block_to_rank[block].append(rank_to_go)
|
||||
# allocate a parameter to a local rank of ParallelMode.ZERO1
|
||||
self._param_to_rank[p] = rank_to_go
|
||||
|
||||
# initialize an empty list for _bcast_handles of each rank
|
||||
for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
|
||||
self._bcast_handles[rank] = []
|
||||
|
||||
# register_forward_pre_hook for transformer/embeding/norm/xxx block
|
||||
self._register_sync_parameters_hook()
|
||||
|
||||
def _register_sync_parameters_hook(self) -> None:
|
||||
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
|
||||
bcast_handles = []
|
||||
# gather all required broadcast hanles into a list
|
||||
for rank in self._block_to_rank[model]:
|
||||
bcast_handles.extend(self._bcast_handles[rank])
|
||||
# need to clear _bcast_handles since they would be processed later
|
||||
self._bcast_handles[rank] = []
|
||||
# wait all required broadcast handles to be completed
|
||||
for handle in bcast_handles:
|
||||
handle.wait()
|
||||
|
||||
# register_forward_pre_hook for transformer/embeding/norm/xxx block
|
||||
for block, _ in self._block_to_rank.items():
|
||||
block.register_forward_pre_hook(partial(_pre_forward_hook))
|
||||
|
||||
def get_rank_by_param(self, param) -> int:
|
||||
return self._param_to_rank[param]
|
||||
|
||||
def add_bcast_handle(self, rank, handle) -> None:
|
||||
self._bcast_handles[rank].append(handle)
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
from .training_internlm import (
|
||||
get_train_data_loader,
|
||||
get_validation_data_loader,
|
||||
initialize_llm_profile,
|
||||
initialize_model,
|
||||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_train_data_loader",
|
||||
"get_validation_data_loader",
|
||||
"initialize_llm_profile",
|
||||
"initialize_model",
|
||||
"initialize_optimizer",
|
||||
"load_new_batch",
|
||||
"record_current_batch_training_metrics",
|
||||
]
|
|
@ -0,0 +1,414 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
|
||||
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
|
||||
from internlm.data.dataset import get_dataset_dict
|
||||
from internlm.data.dummy_dataset import RandomDataset
|
||||
from internlm.data.packed_dataset import (
|
||||
PackedDataset,
|
||||
PackedDatasetWithoutCuSeqlen,
|
||||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.monitor import set_env_var
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
|
||||
from internlm.utils.common import DummyProfile
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import (
|
||||
is_no_pp_or_last_stage,
|
||||
sync_model_param,
|
||||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_model():
|
||||
"""
|
||||
Initialize model.
|
||||
|
||||
Returns: The neural network model to be trained or evaluated.
|
||||
"""
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
if isinstance(model, nn.ModuleList):
|
||||
model = nn.ModuleList(
|
||||
[
|
||||
NaiveAMPModel(
|
||||
model=_m,
|
||||
output_to_fp32=False, # manually controlled by interleaved pipleline scheduler
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
for _m in model
|
||||
]
|
||||
)
|
||||
else:
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
|
||||
# This sync is very important, cause the model weights kept in optimizer are copied
|
||||
# from the origin parameters in the memory, so we should make sure the dp sync
|
||||
# does not influence the model weights in optimizer be different with the origin parameters.
|
||||
sync_model_param(model, parallel_mode=ParallelMode.DATA)
|
||||
|
||||
# This function is needed to make sure parameters that are not splitted by tensor parallelism are
|
||||
# the same across tensor parallelism.
|
||||
sync_model_param_within_tp(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
||||
"""
|
||||
Initialize optimizer.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Your model instance to be trained or evaluated.
|
||||
|
||||
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
|
||||
"""
|
||||
if gpc.config.hybrid_zero_optimizer.overlap_sync_param:
|
||||
param_bcast_sync_handler = ParamBcastSyncHandler(model)
|
||||
else:
|
||||
param_bcast_sync_handler = None
|
||||
|
||||
adam_cfg = gpc.config.adam
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
|
||||
lr=adam_cfg.lr,
|
||||
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
|
||||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
zero_cfg=gpc.config.hybrid_zero_optimizer,
|
||||
param_bcast_sync_handler=param_bcast_sync_handler,
|
||||
)
|
||||
|
||||
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
|
||||
|
||||
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
|
||||
|
||||
return optimizer, beta2_scheduler, lr_scheduler
|
||||
|
||||
|
||||
def get_train_data_loader(
|
||||
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
|
||||
):
|
||||
"""
|
||||
Generate and return the training data loader.
|
||||
|
||||
Returns: A tuple of (train_dl, dataset_types).
|
||||
"""
|
||||
|
||||
# Get the dataset types
|
||||
dataset_types = None
|
||||
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
# Get the sample weight dictionary
|
||||
train_folder = data_cfg.train_folder
|
||||
|
||||
if not train_folder:
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
if dataset_generate_func is not None:
|
||||
train_ds = dataset_generate_func()
|
||||
else:
|
||||
train_ds = get_packed_dataset_without_short_length(
|
||||
folder=data_cfg.train_folder,
|
||||
packed_length=data_cfg.packed_length,
|
||||
max_length_per_sample=data_cfg.seq_len,
|
||||
show_progress=dist.get_rank() == 0,
|
||||
min_length=data_cfg.min_length,
|
||||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||
)
|
||||
|
||||
if dataset_generate_func is None or not train_folder:
|
||||
# partition already completed
|
||||
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
|
||||
# Create the training dataset sampler
|
||||
train_sampler = StaticBatchSampler(
|
||||
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
|
||||
batch_size=data_cfg.micro_num,
|
||||
rampup_batch_size=data_cfg.rampup_batch_size,
|
||||
micro_bsz=data_cfg.micro_bsz,
|
||||
seed=1024,
|
||||
drop_last=True,
|
||||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
|
||||
if dataset_generate_func is None or not train_folder:
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
# Create the training data loader
|
||||
train_dl = DataLoader(
|
||||
dataset=train_ds,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=num_worker,
|
||||
pin_memory=True,
|
||||
collate_fn=train_collate_fn,
|
||||
persistent_workers=num_worker > 0,
|
||||
)
|
||||
|
||||
return train_dl, dataset_types
|
||||
|
||||
|
||||
def get_validation_data_loader(
|
||||
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
|
||||
):
|
||||
"""Generate and return the validation data loader."""
|
||||
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
if not data_cfg.valid_folder:
|
||||
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
|
||||
else:
|
||||
if dataset_generate_func is not None:
|
||||
assert val_collate_fn and dataloader_func is not None
|
||||
val_ds = dataset_generate_func()
|
||||
else:
|
||||
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
|
||||
|
||||
if not isinstance(val_ds, dict):
|
||||
val_ds = {"val": val_ds}
|
||||
|
||||
if val_collate_fn is None or not data_cfg.valid_folder:
|
||||
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
|
||||
|
||||
val_dls = {}
|
||||
for val_name, ds in val_ds.items():
|
||||
if dataloader_func and data_cfg.valid_folder is not None:
|
||||
val_dls[val_name] = dataloader_func(dataset=ds, collate_fn=val_collate_fn)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"load validation dataset {val_name} with valid batch size {str(data_cfg.valid_micro_num)} and "
|
||||
f"{ds.size} Byte samples."
|
||||
)
|
||||
else:
|
||||
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
|
||||
# otherwise too much data may be dropped
|
||||
batch_size = min(
|
||||
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
|
||||
)
|
||||
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
|
||||
|
||||
if batch_size == 0 and gpc.is_rank_for_log():
|
||||
logger.info(f"skip validate {val_name}.")
|
||||
continue
|
||||
|
||||
val_dls[val_name] = get_dpsampler_dataloader(
|
||||
ds,
|
||||
shuffle=False,
|
||||
num_workers=num_worker,
|
||||
batch_size=batch_size,
|
||||
collate_fn=val_collate_fn,
|
||||
drop_last=True,
|
||||
) # drop_last=True, otherwise it may cause problems in the last batch
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
|
||||
f"samples {str(len(val_dls[val_name]))}."
|
||||
)
|
||||
|
||||
return val_dls
|
||||
|
||||
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
||||
Args:
|
||||
train_dl (torch.utils.data.DataLoader): Dataloader for training.
|
||||
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
train_state (TrainState): Current training state.
|
||||
|
||||
Returns: A batch data and the updated train_iter.
|
||||
"""
|
||||
|
||||
timer("batch-gen").start()
|
||||
try:
|
||||
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
|
||||
if hasattr(train_state, "batch_sampler_iter"):
|
||||
next(train_state.batch_sampler_iter)
|
||||
except StopIteration:
|
||||
train_iter = iter(train_dl)
|
||||
batch = next(train_iter)
|
||||
train_state.num_consumed_samples_in_epoch = 0
|
||||
if hasattr(train_state, "batch_sampler"):
|
||||
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
|
||||
next(train_state.batch_sampler_iter)
|
||||
timer("batch-gen").stop()
|
||||
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
||||
def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
||||
"""Initialize and return the profiler context manager instance."""
|
||||
|
||||
if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
llm_profile = torch.profiler.profile
|
||||
logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
|
||||
else:
|
||||
llm_profile = DummyProfile
|
||||
|
||||
return llm_profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
||||
schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
|
||||
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||
+ f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
|
||||
),
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
)
|
||||
|
||||
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
writer,
|
||||
success_update,
|
||||
batch_count,
|
||||
batch,
|
||||
train_state,
|
||||
optimizer,
|
||||
beta2_scheduler,
|
||||
trainer,
|
||||
start_time,
|
||||
loss,
|
||||
grad_norm,
|
||||
metric,
|
||||
update_panel,
|
||||
):
|
||||
"""
|
||||
Print some training metrics of current batch.
|
||||
"""
|
||||
|
||||
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
|
||||
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
if is_no_pp_or_last_stage():
|
||||
acc_perplex = metric.get_metric()
|
||||
|
||||
if success_update and gpc.is_rank_for_log():
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
if hasattr(trainer.engine.optimizer, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
|
||||
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
|
||||
|
||||
num_tokens_in_batch = batch[1].nelement()
|
||||
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
|
||||
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
|
||||
tk_per_gpu = 0
|
||||
tk_per_gpu = round(
|
||||
num_tokens_in_batch
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
/ gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
/ (time.time() - start_time),
|
||||
2,
|
||||
)
|
||||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item(),
|
||||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm,
|
||||
}
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
|
||||
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
|
||||
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
|
||||
infos["largest_length"] = max_length_in_batch # the longest input
|
||||
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
|
||||
infos["smallest_batch"] = min_samples_in_batch
|
||||
infos["adam_beta2"] = beta2_scheduler.get_beta2()
|
||||
|
||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||
infos["fwd_bwd_time"] = fwd_bwd_time
|
||||
|
||||
for key, value in acc_perplex.items():
|
||||
infos[key] = value
|
||||
|
||||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value} "
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
if update_panel:
|
||||
logger.info(
|
||||
line,
|
||||
extra={
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"grad_norm": grad_norm,
|
||||
"loss": loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(line)
|
||||
|
||||
# if loss spike occurs, send alert info to feishu
|
||||
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
|
|
@ -34,18 +34,6 @@ def get_master_node():
|
|||
return result
|
||||
|
||||
|
||||
def get_process_rank():
|
||||
proc_rank = -1
|
||||
if os.getenv("SLURM_PROCID") is not None:
|
||||
proc_rank = int(os.getenv("SLURM_PROCID"))
|
||||
elif os.getenv("RANK") is not None:
|
||||
# In k8s env, we use $RANK.
|
||||
proc_rank = int(os.getenv("RANK"))
|
||||
|
||||
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
|
||||
return proc_rank
|
||||
|
||||
|
||||
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != "cuda":
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
|
@ -81,28 +69,12 @@ def move_to_device(data):
|
|||
data_to_return = []
|
||||
for element in data:
|
||||
if isinstance(element, dict):
|
||||
data_to_return.append(
|
||||
{
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in element.items()
|
||||
}
|
||||
)
|
||||
data_to_return.append({k: _move_tensor(v) for k, v in element.items()})
|
||||
else:
|
||||
data_to_return.append(_move_tensor(element))
|
||||
data = data_to_return
|
||||
elif isinstance(data, dict):
|
||||
data = {
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
data = {k: _move_tensor(v) for k, v in data.items()}
|
||||
else:
|
||||
raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
return data
|
||||
|
@ -246,3 +218,21 @@ def get_megatron_flops(
|
|||
|
||||
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
|
||||
return tflops
|
||||
|
||||
|
||||
class DummyProfile:
|
||||
"""
|
||||
Dummy Profile.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, a, b, c):
|
||||
pass
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.model.metrics import AccPerplex
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
|
||||
if not gpc.is_using_pp():
|
||||
prev_data_process_func = trainer.schedule.data_process_func
|
||||
prev_grad_accum_size = trainer.schedule._grad_accum_size
|
||||
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
|
||||
prev_metric_hooks = trainer.schedule._hooks
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule._grad_accum_size = grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
|
||||
trainer.schedule._hooks = metric_hook_list
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = prev_data_process_func
|
||||
trainer.schedule._grad_accum_size = prev_grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
|
||||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
|
||||
if gpc.is_using_pp():
|
||||
pre_data_process_func = trainer.schedule.data_process_func
|
||||
prev_num_microbatches = trainer.schedule.num_microbatches
|
||||
prev_tensor_shape = trainer.schedule.tensor_shape
|
||||
prev_metric_hooks = trainer.schedule._hooks
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule.num_microbatches = num_microbatches
|
||||
trainer.schedule.tensor_shape = tensor_shape
|
||||
trainer.schedule._hooks = metric_hook_list
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = pre_data_process_func
|
||||
trainer.schedule.num_microbatches = prev_num_microbatches
|
||||
trainer.schedule.tensor_shape = prev_tensor_shape
|
||||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_sequence_parallel_mode():
|
||||
prev_mode = gpc.config.parallel.sequence_parallel
|
||||
try:
|
||||
gpc.config.parallel.sequence_parallel = False
|
||||
yield
|
||||
finally:
|
||||
gpc.config.parallel.sequence_parallel = prev_mode
|
||||
|
||||
|
||||
def evaluate_on_val_dls(
|
||||
trainer,
|
||||
val_dls,
|
||||
writer,
|
||||
logger,
|
||||
step_count,
|
||||
update_panel: bool = False,
|
||||
streaming: bool = False,
|
||||
):
|
||||
with switch_sequence_parallel_mode():
|
||||
torch.cuda.empty_cache()
|
||||
trainer.eval()
|
||||
verbose = gpc.is_rank_for_log()
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
for val_name, val_dl in val_dls.items():
|
||||
if len(val_dl) == 0 and verbose and not streaming:
|
||||
logger.info(f"Validation dataset: {val_name} is empty")
|
||||
continue
|
||||
|
||||
val_metric = AccPerplex(
|
||||
device=torch.cuda.current_device(),
|
||||
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||
)
|
||||
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
|
||||
|
||||
val_loss = 0
|
||||
val_idx = -1
|
||||
for val_idx, batch in tqdm(
|
||||
enumerate(val_dl),
|
||||
desc="Val.",
|
||||
total=len(val_dl) if not streaming else None,
|
||||
position=1,
|
||||
disable=not verbose,
|
||||
leave=False,
|
||||
):
|
||||
with torch.inference_mode():
|
||||
if gpc.is_using_pp():
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
num_microbatches=num_microbatches,
|
||||
tensor_shape=tensor_shape,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
else:
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
grad_accum_size=grad_accum_size,
|
||||
grad_accum_batch_size=grad_accum_batch_size,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
if verbose:
|
||||
val_loss += loss.item()
|
||||
|
||||
assert val_idx != -1
|
||||
dist.barrier()
|
||||
|
||||
val_res = val_metric.get_metric()
|
||||
if verbose and len(val_dl) != 0:
|
||||
val_loss = val_loss / (val_idx + 1 + 1e-6)
|
||||
infos = {
|
||||
"step": step_count,
|
||||
f"val/{val_name}_loss": val_loss,
|
||||
f"val/{val_name}_acc": val_res["acc"],
|
||||
f"val/{val_name}_plex": val_res["perplexity"],
|
||||
}
|
||||
|
||||
for key, value in infos.items():
|
||||
writer.add_scalar(key=key, value=value, step=step_count)
|
||||
|
||||
if update_panel:
|
||||
logger.info(
|
||||
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
|
||||
extra={
|
||||
"step": step_count,
|
||||
"val_loss": val_loss,
|
||||
"val_acc": val_res["acc"],
|
||||
"val_perplexity": val_res["perplexity"],
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
|
@ -2,6 +2,7 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
LOGGER_NAME = "internlm"
|
||||
LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
|
||||
|
@ -11,6 +12,8 @@ LOGGER_LEVEL_HELP = (
|
|||
"The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'"
|
||||
)
|
||||
|
||||
uniscale_logger = None
|
||||
|
||||
|
||||
def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger:
|
||||
"""Configure the logger that is used for uniscale framework.
|
||||
|
@ -24,6 +27,10 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
|
|||
logger (logging.Logger): the created or modified logger.
|
||||
|
||||
"""
|
||||
|
||||
if uniscale_logger is not None:
|
||||
return uniscale_logger
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
|
||||
if logging_level not in LOGGER_LEVEL_CHOICES:
|
||||
|
@ -39,3 +46,53 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
|
|||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def initialize_uniscale_logger(
|
||||
job_name: str = None,
|
||||
launch_time: str = None,
|
||||
file_name: str = None,
|
||||
name: str = LOGGER_NAME,
|
||||
level: str = LOGGER_LEVEL,
|
||||
file_path: str = None,
|
||||
is_std: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize uniscale logger.
|
||||
|
||||
Args:
|
||||
job_name (str): The name of training job, defaults to None.
|
||||
launch_time (str): The launch time of training job, defaults to None.
|
||||
file_name (str): The log file name, defaults to None.
|
||||
name (str): The logger name, defaults to "internlm".
|
||||
level (str): The log level, defaults to "info".
|
||||
file_path (str): The log file path, defaults to None.
|
||||
is_std (bool): Whether to output to console, defaults to True.
|
||||
|
||||
Returns:
|
||||
Uniscale logger instance.
|
||||
"""
|
||||
|
||||
try:
|
||||
from uniscale_monitoring import get_logger as get_uniscale_logger
|
||||
except ImportError:
|
||||
print("Failed to import module uniscale_monitoring. Use default python logger.")
|
||||
return None
|
||||
|
||||
if not file_path:
|
||||
assert (
|
||||
job_name and launch_time and file_name
|
||||
), "If file_path is None, job_name, launch_time and file_name must be setted."
|
||||
log_file_name = file_name
|
||||
log_folder = os.path.join(job_name, launch_time, "logs")
|
||||
log_dir = os.path.join(log_folder, log_file_name)
|
||||
file_path = log_dir
|
||||
|
||||
logger = get_uniscale_logger(name=name, level=level, filename=file_path, is_std=is_std)
|
||||
if isinstance(logger, (list, tuple)):
|
||||
logger = list(logger)[0]
|
||||
|
||||
global uniscale_logger
|
||||
uniscale_logger = logger
|
||||
|
||||
return logger
|
||||
|
|
|
@ -14,18 +14,19 @@ class _Timer:
|
|||
self.elapsed_ = 0.0
|
||||
self.started_ = False
|
||||
self.start_time = time.time()
|
||||
self.stream = torch.cuda.current_stream()
|
||||
|
||||
def start(self):
|
||||
"""Start the timer."""
|
||||
assert not self.started_, "timer has already been started"
|
||||
torch.cuda.synchronize()
|
||||
self.stream.synchronize()
|
||||
self.start_time = time.time()
|
||||
self.started_ = True
|
||||
|
||||
def stop(self):
|
||||
"""Stop the timer."""
|
||||
assert self.started_, "timer is not started"
|
||||
torch.cuda.synchronize()
|
||||
self.stream.synchronize()
|
||||
self.elapsed_ += time.time() - self.start_time
|
||||
self.started_ = False
|
||||
|
||||
|
|
|
@ -2,8 +2,11 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import fcntl
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
@ -11,15 +14,26 @@ import torch
|
|||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
|
||||
from internlm.utils.storage_manager import (
|
||||
get_fns,
|
||||
get_storage_manager,
|
||||
llm_load,
|
||||
llm_save,
|
||||
)
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class CheckpointType(Enum):
|
||||
NORMAL_CHECKPOINT = 1
|
||||
SNAPSHOT_CHECKPOINT = 2
|
||||
|
||||
|
||||
def get_model_topology(model):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -138,11 +152,13 @@ def save_optimizer_checkpoint(optim, state_path):
|
|||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||
|
||||
states = optim.state_dict()
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
if gpc.get_global_rank() < optim.zero_world_size:
|
||||
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
if "zero_devide_optim_plan" in states:
|
||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||
|
@ -152,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path):
|
|||
llm_save(os.path.join(state_path, fp), states)
|
||||
|
||||
|
||||
def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
|
||||
"""
|
||||
Save checkpoint to the given folder path.
|
||||
"""
|
||||
|
||||
start = time.time()
|
||||
torch.distributed.barrier()
|
||||
folder = os.path.join(folder, str(train_state.step_count))
|
||||
logger.info(
|
||||
f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
|
||||
)
|
||||
|
||||
timer("save-model").start()
|
||||
save_model_checkpoint(folder=folder, model=model)
|
||||
timer("save-model").stop()
|
||||
|
||||
timer("save-optimizer").start()
|
||||
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
|
||||
timer("save-optimizer").stop()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
scheduler_states = scheduler.state_dict()
|
||||
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
|
||||
|
||||
sampler_state = train_state.batch_sampler.state_dict()
|
||||
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
|
||||
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
|
||||
|
||||
if model_config is not None:
|
||||
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
timer.log(["save-model", "save-optimizer"], logger=logger)
|
||||
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
|
||||
|
||||
|
||||
def load_optimizer_checkpoint(folder, optim):
|
||||
"""Load the optimizer state from the local file system or remote
|
||||
object storage Service (OSS).
|
||||
|
@ -287,3 +265,369 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
|
|||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"reload load_scheduler:{lr_scheduler}")
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""StorageManagerContext"""
|
||||
|
||||
def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
|
||||
"""
|
||||
CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
|
||||
upload mode, you must call wait_async_upload_finish at the end of the program to wait
|
||||
for the asynchronous ckpt upload to complete.
|
||||
|
||||
Args:
|
||||
ckpt_config (dict): model checkpoint config.
|
||||
model (nn.module): model obj
|
||||
optimizer (object): optimzier obj.
|
||||
lr_scheduler (object): lr_scheduler obj.
|
||||
model_config (dict): model config.
|
||||
"""
|
||||
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
|
||||
self.checkpoint_every = ckpt_config.checkpoint_every
|
||||
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
|
||||
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
|
||||
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
|
||||
self.stop_file_path = ckpt_config.stop_file_path
|
||||
self.load_model_only_folder = ckpt_config.load_model_only_folder
|
||||
self.feishu_address = feishu_address
|
||||
self.storage_manager = get_storage_manager()
|
||||
self.snapshot_counter = 0
|
||||
self.load_optimizer = gpc.config.ckpt.load_optimizer
|
||||
|
||||
self.model = model
|
||||
self.model_config = model_config
|
||||
self.model_config_file = model_config_file
|
||||
|
||||
if self.stop_file_path and gpc.get_global_rank() == 0:
|
||||
dir_path = os.path.dirname(self.stop_file_path)
|
||||
if dir_path != "" and not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
with open(self.stop_file_path, "w", encoding="utf-8") as f:
|
||||
f.write("0")
|
||||
|
||||
if ckpt_config.load_given_ckpt is False:
|
||||
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
|
||||
latest_ckpt_path = self.query_lastest_ckpt()
|
||||
if latest_ckpt_path:
|
||||
self.load_ckpt_folder = latest_ckpt_path
|
||||
else:
|
||||
# At this time, we have to load model init weights and train from step 0.
|
||||
self.load_ckpt_folder = self.load_model_only_folder
|
||||
else:
|
||||
self.load_ckpt_folder = ckpt_config.load_ckpt_folder
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
|
||||
if self.stop_file_path is None:
|
||||
logger.warning("no set stop_file_path, quit_signal_handler is disable")
|
||||
|
||||
def quit_signal_handler(self, train_state) -> bool:
|
||||
"""
|
||||
Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,
|
||||
all ranks will save ckpt and exit.
|
||||
Negative integer step means save ckpt.
|
||||
Positive integer step means save ckpt and quit.
|
||||
|
||||
Args:
|
||||
train_state (TrainState):
|
||||
Returns:
|
||||
bool: whether to quit.
|
||||
"""
|
||||
now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
|
||||
|
||||
if self.stop_file_path is None:
|
||||
return now_break, now_save_ckpt, save_type
|
||||
|
||||
with open(self.stop_file_path, "a+", encoding="utf-8") as f:
|
||||
fcntl.flock(f, fcntl.LOCK_EX)
|
||||
f.seek(0)
|
||||
msg = f.read()
|
||||
fcntl.flock(f, fcntl.LOCK_UN)
|
||||
action_step = int(msg)
|
||||
|
||||
if action_step < 0 and abs(action_step) == train_state.step_count:
|
||||
now_save_ckpt = True
|
||||
|
||||
if action_step > 0 and action_step == train_state.step_count:
|
||||
now_break, now_save_ckpt = True, True
|
||||
|
||||
if action_step != 0 and gpc.is_rank_for_log():
|
||||
msg = "Stop" if action_step > 0 else "Save"
|
||||
action_step = abs(action_step)
|
||||
if train_state.step_count <= action_step:
|
||||
if self.feishu_address:
|
||||
send_alert_message(
|
||||
address=self.feishu_address,
|
||||
message=f"training will {msg} at step_count {action_step}!\
|
||||
now step_count is {train_state.step_count}",
|
||||
)
|
||||
|
||||
return now_break, now_save_ckpt, save_type
|
||||
|
||||
def try_save_checkpoint(self, train_state):
|
||||
if not self.enable_save_ckpt:
|
||||
return False
|
||||
|
||||
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
|
||||
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
|
||||
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
|
||||
if train_state.step_count % self.checkpoint_every == 0:
|
||||
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
|
||||
now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
|
||||
if save_ckpts is False:
|
||||
save_ckpts = singal_save_ckpts
|
||||
save_type = singal_save_type
|
||||
|
||||
if save_ckpts:
|
||||
# Wait for the previous round of asynchronous upload storage to complete.
|
||||
self.storage_manager.wait()
|
||||
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
|
||||
# Snapshot number, with only two snapshots written alternately.
|
||||
self.snapshot_counter = (self.snapshot_counter + 1) % 2
|
||||
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
|
||||
else:
|
||||
save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))
|
||||
|
||||
self.save_checkpoint(
|
||||
folder=save_ckpt_folder,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
scheduler=self.lr_scheduler,
|
||||
train_state=train_state,
|
||||
model_config=self.model_config,
|
||||
model_config_file=self.model_config_file,
|
||||
)
|
||||
|
||||
return now_break
|
||||
|
||||
def wait_async_upload_finish(self):
|
||||
"""wait for all checkpoint uploads to be completed"""
|
||||
self.storage_manager.wait()
|
||||
torch.distributed.barrier()
|
||||
|
||||
def query_latest_snapshot_step_boto3(self):
|
||||
"""query_latest_snapshot_step_boto3
|
||||
Returns:
|
||||
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
|
||||
"""
|
||||
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
|
||||
if len(ckpt_list) == 0:
|
||||
return None, None
|
||||
|
||||
max_normal_step = 0
|
||||
ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
|
||||
ckpt_list.sort(reverse=True)
|
||||
for ckpt in ckpt_list:
|
||||
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
|
||||
for fn in fns_list:
|
||||
if fn.endswith(".step"):
|
||||
max_normal_step = ckpt
|
||||
break
|
||||
if max_normal_step != 0:
|
||||
break
|
||||
|
||||
max_normal_step = ckpt_list[0]
|
||||
load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
|
||||
|
||||
snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
|
||||
snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
|
||||
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
|
||||
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
|
||||
max_step_0, max_step_1 = 0, 0
|
||||
for ckpt in ckpt_list_1:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
|
||||
for ckpt in ckpt_list_2:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
|
||||
|
||||
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
|
||||
snap_step = max(max_step_0, max_step_1)
|
||||
load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
|
||||
load_step = max(snap_step, max_normal_step)
|
||||
return load_path, load_step
|
||||
|
||||
def query_latest_snapshot_step_local(self):
|
||||
max_step, max_step_path = 0, None
|
||||
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
|
||||
for fn in files:
|
||||
fn = fn.strip("/")
|
||||
if fn.endswith(".step"):
|
||||
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
|
||||
# as an integrity flag.
|
||||
step = int(fn.rsplit(".", maxsplit=1)[0])
|
||||
if max_step < step:
|
||||
max_step = step
|
||||
max_step_path = root
|
||||
|
||||
return max_step_path, max_step
|
||||
|
||||
def query_lastest_ckpt(self):
|
||||
latest_checkpoint = None
|
||||
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
|
||||
if self.save_ckpt_folder:
|
||||
if self.save_ckpt_folder.startswith("boto3"):
|
||||
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
|
||||
elif self.save_ckpt_folder.startswith("local"):
|
||||
latest_checkpoint, step = self.query_latest_snapshot_step_local()
|
||||
else:
|
||||
latest_checkpoint, step = None, 0
|
||||
|
||||
if latest_checkpoint is not None:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
|
||||
send_alert_message(
|
||||
address=self.feishu_address,
|
||||
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
|
||||
)
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
send_alert_message(
|
||||
address=self.feishu_address,
|
||||
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
|
||||
)
|
||||
|
||||
return latest_checkpoint
|
||||
|
||||
def try_load_model(self, current_time=""):
|
||||
model_load_path = None
|
||||
|
||||
if self.load_ckpt_folder and self.load_model_only_folder:
|
||||
raise ValueError(
|
||||
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
|
||||
if you only need to load model weights (for example starting an SFT task for the first time), \
|
||||
set load_model_only_folder path, if you need to resume training from ckpt, \
|
||||
set load_ckpt_folder or use default value \
|
||||
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
|
||||
)
|
||||
|
||||
if self.load_ckpt_folder:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = self.load_ckpt_folder
|
||||
elif self.load_model_only_folder:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = self.load_model_only_folder
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
)
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
if model_load_path is not None:
|
||||
load_model_checkpoint(folder=model_load_path, model=self.model)
|
||||
|
||||
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
|
||||
"""Attempt to restore the training state of the last ckpt.
|
||||
|
||||
Args:
|
||||
lr_scheduler (_LRScheduler): lr_scheduler object.
|
||||
optimizer (Optimizer): optimizer object.
|
||||
lr (float): learning rate.
|
||||
train_state (dict): traing states.
|
||||
train_dl (DataLoader): traning dataloader object
|
||||
"""
|
||||
if self.load_ckpt_folder is not None:
|
||||
# load optimzier states.
|
||||
if self.load_optimizer:
|
||||
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
|
||||
# load lr scheduler states.
|
||||
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
|
||||
# load training states.
|
||||
load_context(self.load_ckpt_folder, train_dl, train_state)
|
||||
# load dataloader sampler states.
|
||||
if hasattr(train_state, "batch_sampler") and not isinstance(
|
||||
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
|
||||
):
|
||||
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
|
||||
if hasattr(train_state, "data_state_dict"):
|
||||
train_dl.dataset.load_state_dict(
|
||||
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
|
||||
)
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
folder,
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
train_state: TrainState,
|
||||
model_config: Dict = None,
|
||||
model_config_file: str = None,
|
||||
):
|
||||
"""
|
||||
Save checkpoint to the given folder path.
|
||||
"""
|
||||
|
||||
start = time.time()
|
||||
self.set_save_folder(folder, train_state.step_count)
|
||||
torch.cuda.synchronize()
|
||||
torch.distributed.barrier()
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
|
||||
|
||||
timer("save-model").start()
|
||||
save_model_checkpoint(folder=folder, model=model)
|
||||
timer("save-model").stop()
|
||||
|
||||
timer("save-optimizer").start()
|
||||
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
|
||||
timer("save-optimizer").stop()
|
||||
|
||||
if (
|
||||
hasattr(train_state, "data_state_dict")
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
and gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
||||
):
|
||||
llm_save(
|
||||
os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"),
|
||||
saved_obj=train_state.data_state_dict,
|
||||
)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
scheduler_states = scheduler.state_dict()
|
||||
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
|
||||
if hasattr(train_state, "batch_sampler") and not isinstance(
|
||||
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
|
||||
):
|
||||
sampler_state = train_state.batch_sampler.state_dict()
|
||||
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
|
||||
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
|
||||
|
||||
if model_config is not None:
|
||||
# Model configuration dictionary.
|
||||
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
|
||||
|
||||
if model_config_file is not None:
|
||||
# The complete training config file content, stored in binary format.
|
||||
llm_save(os.path.join(folder, "config_file.pt"), saved_obj=model_config_file)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
timer.log(["save-model", "save-optimizer"], logger=logger)
|
||||
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
|
||||
if self.storage_manager.async_mode is False:
|
||||
llm_save(
|
||||
os.path.join(folder, f"{train_state.step_count}.step"),
|
||||
saved_obj=dict({"step": train_state.step_count}),
|
||||
)
|
||||
|
||||
def set_save_folder(self, folder, step):
|
||||
self.storage_manager.latest_save_folder = folder
|
||||
self.storage_manager.latest_save_step = step
|
||||
|
|
|
@ -46,3 +46,16 @@ def sync_model_param_within_tp(model):
|
|||
|
||||
def is_no_pp_or_last_stage():
|
||||
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
|
||||
def get_parallel_log_file_name():
|
||||
if gpc.is_rank_for_log():
|
||||
fn_prefix = "main_" # Indicates a rank with more output information
|
||||
else:
|
||||
fn_prefix = ""
|
||||
|
||||
log_file_name = (
|
||||
f"{fn_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
|
||||
)
|
||||
return log_file_name
|
||||
|
|
|
@ -22,9 +22,9 @@ class Registry:
|
|||
"""Registers a module represented in `module_class`.
|
||||
|
||||
Args:
|
||||
module_class (class): The module to be registered.
|
||||
module_name (str): The name of module to be registered.
|
||||
Returns:
|
||||
class: The module to be registered, so as to use it normally if via importing.
|
||||
function: The module to be registered, so as to use it normally if via importing.
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if the module has already been registered before.
|
||||
"""
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from functools import partial, reduce
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import pyecharts
|
||||
import torch
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
|
||||
mb = 1024 * 1024
|
||||
|
||||
|
@ -107,6 +105,8 @@ class SimpleMemState:
|
|||
"""
|
||||
Update the total memory usage of the model and sub-models.
|
||||
"""
|
||||
self._total_mem = self._layer_mem
|
||||
|
||||
for stat in self.sub_model_stats.values():
|
||||
# Update sub-model status first.
|
||||
stat.update_total_memory()
|
||||
|
@ -169,6 +169,39 @@ class SimpleMemState:
|
|||
return {"name": self.layer_name, "children": children}
|
||||
|
||||
|
||||
class ActivationMemState:
|
||||
"""
|
||||
Activation Memory State
|
||||
"""
|
||||
|
||||
def __init__(self, num_chunks: int) -> None:
|
||||
self._num_chunks = num_chunks
|
||||
|
||||
self.inited: List[bool] = [False for _ in range(num_chunks)]
|
||||
self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]
|
||||
|
||||
@property
|
||||
def total_mem(self) -> int:
|
||||
return sum(state.total_mem for state in self.states)
|
||||
|
||||
def dump(self, prefix: str = "") -> str:
|
||||
return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])
|
||||
|
||||
def to_json(self, base: int = 1024 * 1024) -> List:
|
||||
return [state.to_json(base) for state in self.states]
|
||||
|
||||
|
||||
def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
|
||||
num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
|
||||
|
||||
if num_chunks > 1:
|
||||
model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
|
||||
else:
|
||||
model = model.model if isinstance(model, NaiveAMPModel) else model
|
||||
|
||||
return model, num_chunks
|
||||
|
||||
|
||||
class SimpleMemoryProfiler:
|
||||
"""
|
||||
A memory profiler for a llm model.
|
||||
|
@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
|
|||
model (torch.nn.Module): The model to profile.
|
||||
optimizer (torch.optim.Optimizer): The optimizer used for training the model.
|
||||
log_file (str): The file to write the memory state information to.
|
||||
activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
|
||||
total_steps: number of steps to trace.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -186,9 +219,8 @@ class SimpleMemoryProfiler:
|
|||
optimizer: torch.optim.Optimizer,
|
||||
log_folder: str,
|
||||
total_steps: int = 5,
|
||||
activation_config: List[str] = None,
|
||||
):
|
||||
self._model = model
|
||||
self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
|
||||
self._optimizer = optimizer
|
||||
self._log_folder = log_folder
|
||||
self._remaining_steps = total_steps
|
||||
|
@ -197,17 +229,20 @@ class SimpleMemoryProfiler:
|
|||
self._record_start_time = time.time()
|
||||
|
||||
# For activation memory state.
|
||||
self._activation_config = activation_config
|
||||
self._activation_mem_inited: bool = False
|
||||
|
||||
self._activation_mem: int = 0
|
||||
self._activation_max_count = 0
|
||||
self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
|
||||
self._activation_mem_max: int = 0
|
||||
self._activation_base_mems = ActivationMemState(self._num_model_chunks)
|
||||
|
||||
# Check or create log folder
|
||||
os.makedirs(self._log_folder, exist_ok=True)
|
||||
|
||||
# Register activation memory tracking hooks
|
||||
self._register_activation_trace_hooks()
|
||||
if self._num_model_chunks > 1:
|
||||
for chunk_id in range(self._num_model_chunks):
|
||||
self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
|
||||
else:
|
||||
self._register_activation_trace_hooks(0, self._model)
|
||||
|
||||
# Calculate static parameter cuda memory
|
||||
self._param_mem_state = SimpleMemState("param_mem")
|
||||
|
@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
|
|||
self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))
|
||||
|
||||
# Generate the first memory record
|
||||
self.point(create=True)
|
||||
self.point(with_options="params,grads,os_params", create=True)
|
||||
|
||||
def point(self, with_options: str = "", create: bool = False) -> None:
|
||||
"""
|
||||
|
@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
|
|||
if "os_state" in options:
|
||||
layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
|
||||
if "activation_base" in options:
|
||||
layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
|
||||
layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()
|
||||
|
||||
# Write memory state information to log file
|
||||
file_mode = "w" if create else "a"
|
||||
|
@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
|
|||
[self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
|
||||
"os_memory_sunburst",
|
||||
)
|
||||
self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
|
||||
self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
|
||||
# Generate summary sunburst chart
|
||||
summary_sunburst_data = [
|
||||
{"name": "params", "value": self._param_mem_state.total_mem // mb},
|
||||
{"name": "grads", "value": self._grad_mem_state.total_mem // mb},
|
||||
{"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
|
||||
{"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
|
||||
{"name": "activation", "value": self._activation_base_mem.total_mem // mb},
|
||||
{"name": "activation", "value": self._activation_mem_max // mb},
|
||||
]
|
||||
|
||||
self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
|
||||
|
@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
|
|||
{},
|
||||
{
|
||||
"r0": "10%",
|
||||
"r": "40%",
|
||||
"r": "35%",
|
||||
"itemStyle": {"borderWidth": 3},
|
||||
"label": {"align": "left"},
|
||||
},
|
||||
{"r0": "40%", "r": "65%", "label": {"align": "left"}},
|
||||
{"r0": "65%", "r": "80%", "label": {"align": "left"}},
|
||||
{"r0": "35%", "r": "55%", "label": {"align": "left"}},
|
||||
{"r0": "55%", "r": "70%", "label": {"align": "left"}},
|
||||
{"r0": "70%", "r": "80%", "label": {"align": "left"}},
|
||||
{"r0": "80%", "r": "90%", "label": {"align": "left"}},
|
||||
{
|
||||
"r0": "90%",
|
||||
|
@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
|
|||
f"{self._log_folder}/{name}.html"
|
||||
)
|
||||
|
||||
def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
|
||||
def _inner_activation_trace_hook(
|
||||
self,
|
||||
chunk_id: int,
|
||||
layer_name: str,
|
||||
model: Any,
|
||||
inputs: Any,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Hook function to trace the activation memory usage for a inner layer.
|
||||
|
||||
|
@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
|
|||
del model, inputs
|
||||
assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"
|
||||
|
||||
if self._stoped or self._activation_mem_inited:
|
||||
if self._stoped or self._activation_base_mems.inited[chunk_id]:
|
||||
return
|
||||
|
||||
# Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
|
||||
self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
|
||||
self._activation_base_mems.states[chunk_id].add(
|
||||
layer_name, output.element_size() * output.nelement(), flush=False
|
||||
)
|
||||
|
||||
def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
|
||||
def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
|
||||
"""
|
||||
Hook function to trace the activation memory usage for a forward pass.
|
||||
|
||||
|
@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
|
|||
return
|
||||
|
||||
# Check if the activation memory has been initialized
|
||||
if self._activation_mem_inited is False:
|
||||
if self._activation_base_mems.inited[chunk_id] is False:
|
||||
self._activation_base_mems.inited[chunk_id] = True
|
||||
# Update the total memory of the activation base memory state
|
||||
self._activation_base_mem.update_total_memory()
|
||||
self._activation_base_mems.states[chunk_id].update_total_memory()
|
||||
# Set with_options to "activation_base" to include activation_base_layout in the memory dump
|
||||
self._activation_mem_inited = True
|
||||
with_options = "activation_base"
|
||||
else:
|
||||
with_options = ""
|
||||
|
||||
# Accumulate activation memory usage for each forward pass
|
||||
self._activation_mem += self._activation_base_mem.total_mem
|
||||
|
||||
# Update activation max count
|
||||
if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
|
||||
self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
|
||||
self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
|
||||
if self._activation_mem > self._activation_mem_max:
|
||||
self._activation_mem_max = self._activation_mem
|
||||
|
||||
# Trigger a memory record
|
||||
self.point()
|
||||
self.point(with_options)
|
||||
|
||||
def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
|
||||
def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
|
||||
"""
|
||||
Hook function to trace the activation memory usage for a backward pass.
|
||||
|
||||
|
@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
|
|||
return
|
||||
|
||||
# Release activation memory usage for each backward pass
|
||||
self._activation_mem -= self._activation_base_mem.total_mem
|
||||
self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem
|
||||
|
||||
# Trigger a memory record
|
||||
self.point()
|
||||
|
||||
def _register_activation_trace_hooks(self) -> None:
|
||||
def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
|
||||
"""
|
||||
Register activation trace hooks for the model and each submodule in the model.
|
||||
"""
|
||||
|
||||
# Register inner activation trace hooks for each submodule in the model
|
||||
for layer_name in self._activation_config:
|
||||
# Register a hook for every activation
|
||||
model = self._model
|
||||
sub_models = layer_name.split(".")
|
||||
# Get the target sub-model
|
||||
for sub_model_name in sub_models:
|
||||
try:
|
||||
model = model.get_submodule(sub_model_name)
|
||||
except AttributeError:
|
||||
model = None
|
||||
break
|
||||
|
||||
for layer_name, sub_model in model_chunk.named_modules():
|
||||
# Register the hook
|
||||
if model is not None:
|
||||
model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
|
||||
if len(sub_model._modules) != 0:
|
||||
continue # TODO: in some special cases, we may need some additional configuration to correct
|
||||
|
||||
sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))
|
||||
|
||||
# Register a forward hook for the main model to track activation memory usage
|
||||
self._model.register_forward_hook(self._activation_trace_hook_forward)
|
||||
model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
|
||||
# Register a backward hook for the main model to release activation memory usage
|
||||
self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
|
||||
model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))
|
||||
|
||||
def _calc_tensor_memory(
|
||||
self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
|
||||
|
@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
|
|||
self._calc_tensor_memory(root_stat, named_tensors)
|
||||
|
||||
|
||||
def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
|
||||
# TODO: support interleaved pipeline scheduling.
|
||||
assert num_chunks == 1, "Only support num_chunks == 1"
|
||||
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
else:
|
||||
pipeline_size = 1
|
||||
pipeline_rank = 0
|
||||
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
start, end = parts[0]
|
||||
num_blocks = end - start
|
||||
|
||||
block_conf_tmpl = [
|
||||
"mixer.rotary_emb",
|
||||
"mixer.Wqkv",
|
||||
"mixer.inner_attn",
|
||||
"mixer.inner_cross_attn",
|
||||
"mixer.out_proj",
|
||||
# "dropout1", # skip when dropout_selective_checkpoint is True
|
||||
# "dropout2", # skip when dropout_selective_checkpoint is True
|
||||
"norm1",
|
||||
"norm2",
|
||||
"mlp.w1",
|
||||
"mlp.w2",
|
||||
"mlp.w3",
|
||||
]
|
||||
|
||||
block_conf = []
|
||||
for block_id in range(num_blocks):
|
||||
block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]
|
||||
|
||||
# We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
|
||||
# If they don't exist, they will be automatically ignored when registering activation trace hooks.
|
||||
activation_conf = ["embedding", "norm", "head"] + block_conf
|
||||
|
||||
return activation_conf
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
|
@ -635,32 +630,39 @@ if __name__ == "__main__":
|
|||
|
||||
return output
|
||||
|
||||
def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
|
||||
if _num_chunks > 1:
|
||||
_output = _input
|
||||
for _model_chunk in _model_chunks:
|
||||
_output = _model_chunk(_output)
|
||||
else:
|
||||
_output = _model_chunks(_input)
|
||||
|
||||
return _output
|
||||
|
||||
# num_chunks config
|
||||
_num_chunks = 1
|
||||
|
||||
# init model and optimizer
|
||||
_model: torch.nn.Module = SimpleModel()
|
||||
if _num_chunks > 1:
|
||||
_chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
|
||||
_model = torch.nn.ModuleList(_chunks).cuda()
|
||||
else:
|
||||
_model: torch.nn.Module = SimpleModel().cuda()
|
||||
_optimizer = torch.optim.Adam(_model.parameters())
|
||||
|
||||
# create activation config for simple model layer by layer.
|
||||
activation_configs = [
|
||||
# model level 0
|
||||
"layer1",
|
||||
"layer2",
|
||||
"layer3",
|
||||
# model level 1
|
||||
"layer2.layer1",
|
||||
"layer2.layer3",
|
||||
]
|
||||
|
||||
_model.modules()
|
||||
|
||||
# init profiler
|
||||
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
|
||||
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)
|
||||
|
||||
_optimizer.zero_grad()
|
||||
|
||||
x1 = torch.randn((128, 5120))
|
||||
x2 = torch.randn((128, 5120))
|
||||
out1 = _model(x1)
|
||||
out2 = _model(x2)
|
||||
# inputs
|
||||
x1 = torch.randn((128, 5120)).cuda()
|
||||
x2 = torch.randn((128, 5120)).cuda()
|
||||
# forward
|
||||
out1 = _simple_schedule(_num_chunks, _model, x1)
|
||||
out2 = _simple_schedule(_num_chunks, _model, x2)
|
||||
# backward
|
||||
out1.mean().backward()
|
||||
out2.mean().backward()
|
||||
|
||||
|
|
|
@ -1,21 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import socket
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Union
|
||||
import stat
|
||||
from asyncio import InvalidStateError
|
||||
from asyncio.tasks import ALL_COMPLETED
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Union
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import botocore
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
||||
|
@ -41,10 +54,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
|
|||
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
|
||||
|
||||
|
||||
class CheckpointType(Enum):
|
||||
NORMAL_CHECKPOINT = 1
|
||||
|
||||
|
||||
class StorageClient:
|
||||
"""
|
||||
StorageClient as a client for s3 storage access.
|
||||
|
@ -54,7 +63,7 @@ class StorageClient:
|
|||
self.handler = handler
|
||||
|
||||
@staticmethod
|
||||
def load(client, load_path: str, map_location):
|
||||
def load(client, load_path: str, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
|
@ -71,25 +80,51 @@ class StorageClient:
|
|||
|
||||
|
||||
class Boto3MetaInfo:
|
||||
def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
|
||||
self.client = client
|
||||
"""Boto3 meta info for save/load etc."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_async,
|
||||
handler: StorageClient,
|
||||
bucket_name: str,
|
||||
endpoint: str,
|
||||
file_path: str,
|
||||
async_upload_fn: callable,
|
||||
local_nvme_path=None,
|
||||
) -> None:
|
||||
self.is_async = is_async
|
||||
self.client = handler
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.file_path = file_path
|
||||
self.async_upload_fn = async_upload_fn
|
||||
self.local_nvme_path = local_nvme_path
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
|
||||
local_nvme_path: {self.local_nvme_path}"
|
||||
|
||||
|
||||
class LocalMetaInfo:
|
||||
def __init__(self, client: StorageClient, dest_path: str) -> None:
|
||||
self.client = client
|
||||
"""Local meta info for save/load etc."""
|
||||
|
||||
def __init__(self, handler: StorageClient, dest_path: str) -> None:
|
||||
self.is_async = False
|
||||
self.client = handler
|
||||
self.dest_path = dest_path
|
||||
self.async_upload_fn = None
|
||||
|
||||
|
||||
def unpack_meta(meta):
|
||||
args = []
|
||||
is_async = meta.is_async
|
||||
for k, v in meta.__dict__.items():
|
||||
if k == "endpoint":
|
||||
if k in ("endpoint", "async_upload_fn", "is_async"):
|
||||
continue
|
||||
if not is_async and k in ("local_nvme_path",):
|
||||
continue
|
||||
args.append(v)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
|
@ -101,21 +136,6 @@ def compute_file_md5_by_chunk(file_name: str):
|
|||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_boto3_meta(fp: str) -> Boto3MetaInfo:
|
||||
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
|
||||
parts = fp.lstrip("s3://").split(os.path.sep)
|
||||
match = boto3_url_re.match(parts[0])
|
||||
assert match is not None, f"url '{fp}' is not a valid boto3 url"
|
||||
bucket_name, endpoint = match.group(1), match.group(2)
|
||||
endpoint = "http://" + endpoint + ":80"
|
||||
return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
|
||||
|
||||
|
||||
def get_local_meta(fp: str) -> LocalMetaInfo:
|
||||
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
|
||||
return LocalMetaInfo(None, fp)
|
||||
|
||||
|
||||
class Boto3Client(StorageClient):
|
||||
"""
|
||||
Boto3Client
|
||||
|
@ -169,7 +189,9 @@ class Boto3Client(StorageClient):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
|
||||
def sync_upload_fileobj(
|
||||
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
|
||||
): # pylint: disable=W0613
|
||||
assert saved_obj is not None, "saved_obj is None!"
|
||||
try:
|
||||
with io.BytesIO() as f:
|
||||
|
@ -182,7 +204,14 @@ class Boto3Client(StorageClient):
|
|||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
|
||||
def load(
|
||||
handler,
|
||||
bucket_name: str,
|
||||
fp: str,
|
||||
local_nvme_path: str, # pylint: disable=W0613
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
|
||||
|
@ -191,7 +220,7 @@ class Boto3Client(StorageClient):
|
|||
with io.BytesIO() as f:
|
||||
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
|
||||
f.seek(0)
|
||||
states = torch.load(f, *args, map_location=map_location, **kwargs)
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
|
@ -199,28 +228,40 @@ class Boto3Client(StorageClient):
|
|||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(
|
||||
handler,
|
||||
bucket_name: str,
|
||||
fp: str,
|
||||
):
|
||||
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
|
||||
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, bucket_name: str, fp: str):
|
||||
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
|
||||
"""
|
||||
Ref: https://stackoverflow.com/questions/54314563/
|
||||
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
|
||||
"""
|
||||
paginator = handler.client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
|
||||
|
||||
folder_name_list = []
|
||||
for page in pages:
|
||||
for obj in page["Contents"]:
|
||||
fp: str = obj["Key"]
|
||||
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
|
||||
return folder_name_list
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
pth: str = obj["Key"]
|
||||
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
||||
return list(set(folder_name_list))
|
||||
|
||||
@staticmethod
|
||||
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
|
||||
try:
|
||||
with open(local_nvme_path, "rb") as f:
|
||||
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
) from exc
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def delete_obj(handler, fp: str):
|
||||
raise NotImplementedError("boto3 not support delete_obj")
|
||||
|
||||
|
||||
class LocalClient(StorageClient):
|
||||
|
@ -241,11 +282,11 @@ class LocalClient(StorageClient):
|
|||
torch.save(saved_obj, fp, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(handler, fp: str, *args, map_location="cpu", **kwargs):
|
||||
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(fp), f"{fp} is not found!"
|
||||
with open(fp, "rb") as f:
|
||||
states = torch.load(f, map_location=map_location, *args, **kwargs)
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
|
@ -267,9 +308,77 @@ class LocalClient(StorageClient):
|
|||
os.remove(fp)
|
||||
|
||||
|
||||
def get_tmp_file_name(tmp_local_folder: str, fp: str):
|
||||
"""
|
||||
It should be noted that all our temporary files will be stored in the same folder,
|
||||
so the file name passed upstream must be unique.
|
||||
"""
|
||||
base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
|
||||
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
|
||||
pid = os.getpid()
|
||||
# step = self.step_counter
|
||||
return "-".join([base_path, current_time, str(pid)]) + ".tmpfile" # , str(step)
|
||||
|
||||
|
||||
def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
|
||||
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
|
||||
parts = fp.lstrip("s3://").split(os.path.sep)
|
||||
match = boto3_url_re.match(parts[0])
|
||||
assert match is not None, f"url '{fp}' is not a valid boto3 url"
|
||||
bucket_name, endpoint = match.group(1), match.group(2)
|
||||
endpoint = "http://" + endpoint + ":80"
|
||||
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
|
||||
return Boto3MetaInfo(
|
||||
is_async=is_async,
|
||||
handler=None,
|
||||
bucket_name=bucket_name,
|
||||
endpoint=endpoint,
|
||||
file_path=os.path.sep.join(parts[1:]),
|
||||
async_upload_fn=Boto3Client.async_upload_fileobj,
|
||||
local_nvme_path=tmp_step_file,
|
||||
)
|
||||
|
||||
|
||||
def get_local_meta(fp: str) -> LocalMetaInfo:
|
||||
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
|
||||
return LocalMetaInfo(None, fp)
|
||||
|
||||
|
||||
def get_mount_point_free_size(path: str):
|
||||
"""
|
||||
Returns the remaining space of the temporary storage mount point as a percentage.
|
||||
Args:
|
||||
path (str): temporary storage folder path.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the temporary storage folder does not exist,
|
||||
an error will be reported。
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
st = os.statvfs(path)
|
||||
# f_bavail: Number of free blocks for unprivileged users.
|
||||
# f_bsize: Filesystem block size.
|
||||
# return unit is TB.
|
||||
return st.f_bavail * st.f_bsize / (1024**3)
|
||||
|
||||
|
||||
def check_tmp_folder_accessibility(tmp_local_folder: str):
|
||||
"""
|
||||
Check access permissions for temporary storage.
|
||||
"""
|
||||
ret = True
|
||||
if os.path.exists(tmp_local_folder):
|
||||
ret &= os.access(tmp_local_folder, os.W_OK)
|
||||
ret &= os.access(tmp_local_folder, os.R_OK)
|
||||
if ret is False:
|
||||
error_str = f'{socket.gethostname()} dose not have read and write permissions on {tmp_local_folder}"'
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
|
||||
class StorageManager(metaclass=SingletonMeta):
|
||||
"""
|
||||
Storage Manager for saving or loading checkpoint.
|
||||
TODO: add a thread to poll the asynchronous storage state.
|
||||
"""
|
||||
|
||||
BACKEND_TYPE = {"boto3", "local"}
|
||||
|
@ -279,8 +388,44 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
}
|
||||
CLI_DICT = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
|
||||
self._exception_list = []
|
||||
self._to_be_del_files = []
|
||||
self._async_stack = []
|
||||
self.upload_count = 0
|
||||
self.tmp_local_folder = tmp_local_folder
|
||||
self.async_mode = async_mode
|
||||
self.has_warning = False
|
||||
self._async_loop = None
|
||||
self._thread_pool = None
|
||||
self.latest_save_folder = None
|
||||
self.latest_save_step = 0
|
||||
self.async_task_peeding = False
|
||||
|
||||
if enable_save and self.async_mode:
|
||||
self._async_loop = asyncio.new_event_loop()
|
||||
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
|
||||
|
||||
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
|
||||
|
||||
# Try to create tmp folder
|
||||
try:
|
||||
os.makedirs(self.tmp_local_folder, exist_ok=True)
|
||||
os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
# In case it is a directory created by other users, we check the permissions again.
|
||||
check_tmp_folder_accessibility(self.tmp_local_folder)
|
||||
|
||||
# Try to clean tmp folder's empty folder.
|
||||
self.try_delete_tmpfile(self.tmp_local_folder)
|
||||
|
||||
# Avaliable storeage space check.
|
||||
free_size = get_mount_point_free_size(self.tmp_local_folder)
|
||||
if free_size < 0.1:
|
||||
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
|
||||
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
|
||||
|
||||
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
"""
|
||||
|
@ -301,7 +446,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
meta_info = get_local_meta(path)
|
||||
backend_key = backend
|
||||
elif backend == "boto3":
|
||||
meta_info = get_boto3_meta(path)
|
||||
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
|
||||
backend_key = backend + ":" + meta_info.endpoint
|
||||
init_args = (meta_info.endpoint,)
|
||||
if (
|
||||
|
@ -310,10 +455,12 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
raise RuntimeWarning(
|
||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||
the proxy may make boto3 unavailable or affect performance."
|
||||
)
|
||||
if not self.has_warning:
|
||||
logger.warning(
|
||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||
the proxy may make boto3 unavailable or affect performance."
|
||||
)
|
||||
self.has_warning = True
|
||||
|
||||
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
|
||||
|
||||
|
@ -333,19 +480,145 @@ the proxy may make boto3 unavailable or affect performance."
|
|||
meta = self._get_client(path=folder)
|
||||
return meta.client.get_fns(*unpack_meta(meta))
|
||||
|
||||
def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
|
||||
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
|
||||
meta = self._get_client(path=save_path)
|
||||
|
||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||
|
||||
def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
|
||||
if async_upload is None:
|
||||
async_upload = self.async_mode
|
||||
if async_upload:
|
||||
assert (
|
||||
self.tmp_local_folder
|
||||
), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
|
||||
tmp_step_file = meta.local_nvme_path
|
||||
self._to_be_del_files.append(tmp_step_file)
|
||||
with open(tmp_step_file, "wb") as f:
|
||||
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
||||
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
self.async_task_peeding = True
|
||||
else:
|
||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||
self.upload_count += 1
|
||||
|
||||
def load(self, load_path: str, *args, **kwargs) -> Any:
|
||||
self.wait()
|
||||
meta = self._get_client(path=load_path)
|
||||
return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
|
||||
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
|
||||
|
||||
def delete_obj(self, fp: str):
|
||||
meta = self._get_client(path=fp)
|
||||
meta.client.delete_obj(*unpack_meta(meta))
|
||||
|
||||
def _del_tmp_folder(self):
|
||||
for fp in self._to_be_del_files:
|
||||
try:
|
||||
os.remove(fp)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except SystemError as e:
|
||||
logger.error(f'delete file: {fp}, failed for reason:"{e}"')
|
||||
else:
|
||||
pass
|
||||
|
||||
storage_manager = StorageManager()
|
||||
def try_delete_tmpfile(self, tmp_dir: str):
|
||||
"""Delete temporary files in tmp_dir."""
|
||||
|
||||
for filename in os.listdir(tmp_dir):
|
||||
if filename.endswith(".tmpfile"):
|
||||
file_path = os.path.join(tmp_dir, filename)
|
||||
try:
|
||||
os.remove(file_path)
|
||||
logger.info(f"Delete tmpfile: {file_path}")
|
||||
except OSError:
|
||||
# Ignore deletion errors
|
||||
pass
|
||||
|
||||
async def _sync_tasks(self) -> Awaitable[None]:
|
||||
if self._async_stack:
|
||||
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
|
||||
count = 0
|
||||
while self._async_stack:
|
||||
t = self._async_stack[0]
|
||||
try:
|
||||
e = t.exception()
|
||||
if e:
|
||||
self._exception_list.append((e, count))
|
||||
logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")
|
||||
# raise e
|
||||
count += 1
|
||||
self._async_stack.pop(0)
|
||||
except InvalidStateError:
|
||||
# Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
|
||||
pass
|
||||
|
||||
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
|
||||
"""
|
||||
Overview:
|
||||
Execute task in background, then apppend the future instance in _async_stack.
|
||||
Arguments:
|
||||
- fn (:obj:`Callable`): Synchronization fuction.
|
||||
"""
|
||||
if not self._async_loop:
|
||||
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
|
||||
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
|
||||
self._async_stack.append(t)
|
||||
|
||||
def wait(self) -> bool:
|
||||
"""Wait for async operations to complete."""
|
||||
|
||||
if not self.async_mode:
|
||||
return
|
||||
|
||||
if not self.async_task_peeding:
|
||||
return
|
||||
|
||||
if self._async_loop:
|
||||
self._async_loop.run_until_complete(self._sync_tasks())
|
||||
|
||||
if self._exception_list:
|
||||
for error_msg, file_id in self._exception_list:
|
||||
logger.error(
|
||||
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
|
||||
f"failed on step {self.upload_count}: {error_msg}"
|
||||
)
|
||||
|
||||
# TODO: Re-upload in sync mode
|
||||
raise RuntimeError(
|
||||
f"Failed to upload {self._to_be_del_files[file_id]} " f"on step {self.upload_count}: {error_msg}"
|
||||
)
|
||||
|
||||
self._del_tmp_folder()
|
||||
self._exception_list.clear()
|
||||
self._to_be_del_files.clear()
|
||||
self.async_task_peeding = False
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
self.upload_count += 1
|
||||
if self.async_mode:
|
||||
self.save(
|
||||
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
|
||||
saved_obj=dict({"step": self.latest_save_step}),
|
||||
async_upload=False,
|
||||
)
|
||||
|
||||
|
||||
storage_manager: StorageManager = None
|
||||
|
||||
|
||||
def init_storage_manager(ckpt_config):
|
||||
global storage_manager
|
||||
storage_manager = StorageManager(
|
||||
ckpt_config.enable_save_ckpt,
|
||||
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
|
||||
async_mode=ckpt_config.async_upload,
|
||||
)
|
||||
|
||||
|
||||
def get_storage_manager():
|
||||
assert storage_manager is not None, "storage_manager has not been init!"
|
||||
return storage_manager
|
||||
|
||||
|
||||
def wait_async_upload_finish():
|
||||
dist.barrier()
|
||||
storage_manager.wait()
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
def tb_save_run_info(writer, config_lines, global_step=0):
|
||||
writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
|
||||
lines = []
|
||||
for line in config_lines:
|
||||
if line.strip().startswith("#"):
|
||||
continue
|
||||
lines.append(line)
|
||||
writer.add_text(tag="config", text_string="\n".join(lines), global_step=global_step)
|
||||
|
||||
|
||||
def init_tb_writer(
|
||||
job_name: str,
|
||||
launch_time: str,
|
||||
file_name: str,
|
||||
tensorboard_folder: str,
|
||||
resume_tb_folder: str,
|
||||
step_count: int,
|
||||
config: str,
|
||||
logger: logging.Logger,
|
||||
):
|
||||
tb_log_file_name = file_name
|
||||
if not tensorboard_folder:
|
||||
tb_folder = os.path.join(job_name, launch_time, "tensorboards")
|
||||
else:
|
||||
tb_folder = tensorboard_folder
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
# If we don't load ckpt, 'resume_tb_folder' is set as the tensorboard
|
||||
# dir of the last task by 'make_launch_script.sh'.
|
||||
# If we load ckpt, 'resume_tb_folder' will be overwritten as the
|
||||
# reloaded 'train_state.resume_tb_folder'.s
|
||||
if resume_tb_folder is not None:
|
||||
assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
|
||||
if not os.path.exists(resume_tb_folder):
|
||||
logger.error(
|
||||
f"Can't found resume_tb_folder{resume_tb_folder}, \
|
||||
please make sure this folder is located at local file system."
|
||||
)
|
||||
else:
|
||||
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}... ")
|
||||
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
|
||||
os.system(f"chmod -R +w {tb_folder}/")
|
||||
else:
|
||||
logger.info(f"Login tensorboard logs to: {tb_folder}")
|
||||
|
||||
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
|
||||
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
|
||||
writer.add_text(tag="job_name", text_string=job_name, global_step=step_count)
|
||||
writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count)
|
||||
|
||||
torch.distributed.broadcast_object_list([tb_folder], src=0)
|
||||
else:
|
||||
objects = [None]
|
||||
torch.distributed.broadcast_object_list(objects, src=0)
|
||||
tb_folder = objects[0]
|
||||
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
|
||||
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
tb_save_run_info(
|
||||
writer=writer,
|
||||
config_lines=config,
|
||||
global_step=step_count,
|
||||
)
|
||||
|
||||
writer.add_text(
|
||||
tag=f"mapping_{tb_log_file_name}",
|
||||
text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} device={torch.cuda.current_device()}",
|
||||
global_step=step_count,
|
||||
)
|
||||
writer.add_scaler = partial(writer.add_scalar, new_style=True)
|
||||
|
||||
return writer, tb_logdir
|
||||
|
||||
|
||||
class Writer:
|
||||
"""
|
||||
Customed writer based on tensorboard for recording training metrics.
|
||||
|
||||
Args:
|
||||
job_name (str): The name of training job, defaults to None.
|
||||
launch_time (str): A string representing the launch time of the training.
|
||||
file_name (str): The log file name, defaults to None.
|
||||
tensorboard_folder (str): A string representing the folder for saving tensorboard logs.
|
||||
resume_tb_folder (str): A string representing the folder for resuming tensorboard logs.
|
||||
step_count (int): An integer representing the step count of the training.
|
||||
config (str): A string representing the configuration of the training.
|
||||
logger (logging.Logger): A logging.Logger object for logging information during training.
|
||||
enable_tb (bool): A boolean indicating whether to enable the tensorboard writer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job_name: str = None,
|
||||
launch_time: str = None,
|
||||
file_name: str = None,
|
||||
tensorboard_folder: str = None,
|
||||
resume_tb_folder: str = None,
|
||||
step_count: int = 0,
|
||||
config: str = None,
|
||||
logger: logging.Logger = None,
|
||||
enable_tb: bool = True,
|
||||
) -> None:
|
||||
self.enable_tb = enable_tb
|
||||
self.tb_writer, self.tb_logdir = init_tb_writer(
|
||||
job_name=job_name,
|
||||
launch_time=launch_time,
|
||||
file_name=file_name,
|
||||
tensorboard_folder=tensorboard_folder,
|
||||
resume_tb_folder=resume_tb_folder,
|
||||
step_count=step_count,
|
||||
config=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def add_scalar(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_text(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_text(tag=key, text_string=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
|
@ -30,7 +30,6 @@ $ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_
|
|||
```
|
||||
|
||||
可以通过运行以下命令来生成`bin`和`meta`文件:
|
||||
|
||||
```bash
|
||||
$ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin
|
||||
```
|
||||
|
|
|
@ -14,7 +14,6 @@ This directory provide some tools for model training with the following file str
|
|||
We need to use a `tokenizer` to generate `bin` and `meta` files for raw data. We import the tokenizer model by specifying the model weight path in `tools/tokenizer.py`. Currently, we provide `V7.model` to generate tokens. If you want to use a different model, you can modify the model weight path in `tokenizer.py` directly.
|
||||
|
||||
We can run the following command to generate `bin` and `meta` files corresponding to the original data. The parameter `text_input_path` represents the path of the original text data, currently supporting `txt`, `json`, and `jsonl` formats, while `bin_output_path` represents the save path of the generated `bin` files.
|
||||
|
||||
```bash
|
||||
$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
|
||||
```
|
||||
|
|
586
train.py
586
train.py
|
@ -5,342 +5,76 @@ import socket
|
|||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler
|
||||
from internlm.data.collaters import packed_collate_fn
|
||||
from internlm.data.dummy_dataset import RandomDataset
|
||||
from internlm.data.packed_dataset import (
|
||||
PackedDataset,
|
||||
PackedDatasetWithoutCuSeqlen,
|
||||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
||||
from internlm.initialize import initialize_distributed_env
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.train import (
|
||||
get_train_data_loader,
|
||||
get_validation_data_loader,
|
||||
initialize_llm_profile,
|
||||
initialize_model,
|
||||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
)
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_master_node,
|
||||
get_megatron_flops,
|
||||
get_process_rank,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
load_context,
|
||||
load_model_checkpoint,
|
||||
load_optimizer_checkpoint,
|
||||
load_sampler,
|
||||
load_scheduler,
|
||||
save_checkpoint,
|
||||
)
|
||||
from internlm.utils.parallel import (
|
||||
is_no_pp_or_last_stage,
|
||||
sync_model_param,
|
||||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
from internlm.utils.simple_memory_profiler import (
|
||||
SimpleMemoryProfiler,
|
||||
build_activation_config,
|
||||
)
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.parallel import get_parallel_log_file_name
|
||||
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
|
||||
from internlm.utils.writer import Writer
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
|
||||
def initialize_llm_logger(start_time: str):
|
||||
"""
|
||||
Initialize distributed environment for distributed training.
|
||||
Initialize customed uniscale logger.
|
||||
|
||||
Args:
|
||||
config (str): Config file path.
|
||||
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
|
||||
master_port (str): The master port for distributed training. 8888 by default.
|
||||
seed (int, optional): Specified random seed for every process. 1024 by default.
|
||||
start_time (str): The launch time of current training job.
|
||||
|
||||
Returns: The instance of uniscale logger.
|
||||
"""
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if launcher == "torch":
|
||||
internlm.launch_from_torch(config=config, seed=seed)
|
||||
elif launcher == "slurm":
|
||||
internlm.launch_from_slurm(
|
||||
config=config,
|
||||
host=get_master_node(),
|
||||
port=master_port,
|
||||
seed=seed,
|
||||
)
|
||||
else:
|
||||
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
|
||||
|
||||
|
||||
def initialize_model():
|
||||
"""
|
||||
Initialize model.
|
||||
|
||||
Returns: The neural network model to be trained or evaluated.
|
||||
"""
|
||||
|
||||
assert (
|
||||
not hasattr(gpc.config.parallel, "pipeline") or gpc.config.parallel.pipeline == 1
|
||||
), "Pipeline parallelism is not supported for now."
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
uniscale_logger = initialize_uniscale_logger(
|
||||
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
|
||||
)
|
||||
if uniscale_logger is not None:
|
||||
global logger
|
||||
logger = uniscale_logger
|
||||
|
||||
# This sync is very important, cause the model weights kept in optimizer are copied
|
||||
# from the origin parameters in the memory, so we should make sure the dp sync
|
||||
# does not influence the model weights in optimizer be different with the origin parameters.
|
||||
sync_model_param(model, parallel_mode=ParallelMode.DATA)
|
||||
|
||||
# This function is needed to make sure parameters that are not splitted by tensor parallelism are
|
||||
# the same across tensor parallelism.
|
||||
sync_model_param_within_tp(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_train_data_loader(num_worker: int = 0):
|
||||
"""
|
||||
Generate and return the training data loader.
|
||||
|
||||
Returns: A tuple of (train_dl, dataset_types).
|
||||
"""
|
||||
|
||||
# Get the dataset types
|
||||
dataset_types = None
|
||||
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
# Get the sample weight dictionary
|
||||
train_folder = data_cfg.train_folder
|
||||
|
||||
if not train_folder:
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = get_packed_dataset_without_short_length(
|
||||
folder=data_cfg.train_folder,
|
||||
packed_length=data_cfg.packed_length,
|
||||
max_length_per_sample=data_cfg.seq_len,
|
||||
show_progress=dist.get_rank() == 0,
|
||||
min_length=data_cfg.min_length,
|
||||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||
)
|
||||
|
||||
# partition already completed
|
||||
# assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen))
|
||||
if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)):
|
||||
datasets = [train_ds]
|
||||
else:
|
||||
datasets = train_ds.datasets
|
||||
|
||||
# Create the training dataset sampler
|
||||
train_sampler = StaticBatchSampler(
|
||||
datasets,
|
||||
batch_size=data_cfg.micro_num,
|
||||
rampup_batch_size=data_cfg.rampup_batch_size,
|
||||
micro_bsz=data_cfg.micro_bsz,
|
||||
seed=1024,
|
||||
drop_last=True,
|
||||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
# Create the training data loader
|
||||
train_dl = DataLoader(
|
||||
dataset=train_ds,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=num_worker,
|
||||
pin_memory=True,
|
||||
collate_fn=train_collate_fn,
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
return train_dl, dataset_types
|
||||
|
||||
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
||||
Args:
|
||||
train_dl (torch.utils.data.DataLoader): Dataloader for training.
|
||||
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
train_state (TrainState): Current training state.
|
||||
|
||||
Returns: A batch data and the updated train_iter.
|
||||
"""
|
||||
|
||||
timer("batch-gen").start()
|
||||
try:
|
||||
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
|
||||
next(train_state.batch_sampler_iter)
|
||||
except StopIteration:
|
||||
train_iter = iter(train_dl)
|
||||
batch = next(train_iter)
|
||||
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
|
||||
next(train_state.batch_sampler_iter)
|
||||
train_state.num_consumed_samples_in_epoch = 0
|
||||
timer("batch-gen").stop()
|
||||
|
||||
batch[0].pop("type_ids", None)
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
||||
def initialize_optimizer(model: nn.Module):
|
||||
"""
|
||||
Initialize optimizer.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Your model instance to be trained or evaluated.
|
||||
|
||||
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
|
||||
"""
|
||||
adam_cfg = gpc.config.adam
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
|
||||
lr=adam_cfg.lr,
|
||||
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
|
||||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
|
||||
)
|
||||
|
||||
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
|
||||
|
||||
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
|
||||
|
||||
return optimizer, beta2_scheduler, lr_scheduler
|
||||
|
||||
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
success_update,
|
||||
batch_count,
|
||||
batch,
|
||||
train_state,
|
||||
optimizer,
|
||||
beta2_scheduler,
|
||||
trainer,
|
||||
start_time,
|
||||
loss,
|
||||
grad_norm,
|
||||
):
|
||||
"""
|
||||
Print some training metrics of current batch.
|
||||
"""
|
||||
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
|
||||
if success_update and gpc.is_rank_for_log():
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
if hasattr(trainer.engine.optimizer, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
|
||||
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
|
||||
|
||||
num_tokens_in_batch = batch[1].nelement()
|
||||
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
|
||||
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
|
||||
tk_per_gpu = 0
|
||||
tk_per_gpu = round(
|
||||
num_tokens_in_batch
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
/ gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
/ (time.time() - start_time),
|
||||
2,
|
||||
)
|
||||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item(),
|
||||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm,
|
||||
}
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
|
||||
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
|
||||
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
|
||||
infos["largest_length"] = max_length_in_batch # the longest input
|
||||
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
|
||||
infos["smallest_batch"] = min_samples_in_batch
|
||||
infos["adam_beta2"] = beta2_scheduler.get_beta2()
|
||||
|
||||
line = ""
|
||||
for k, v in infos.items():
|
||||
line += f"{k}={v},"
|
||||
|
||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||
line += f"fwd_bwd_time={fwd_bwd_time}"
|
||||
|
||||
logger.info(line)
|
||||
return uniscale_logger
|
||||
|
||||
|
||||
def main(args):
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
load_optimizer = gpc.config.ckpt.load_optimizer
|
||||
valid_every = gpc.config.data.valid_every
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
lr = gpc.config.adam.lr
|
||||
|
||||
# ckpt setting
|
||||
save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
|
||||
enable_save_ckpt = gpc.config.ckpt.enable_ckpt
|
||||
checkpoint_every = gpc.config.ckpt.checkpoint_every
|
||||
|
||||
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
|
||||
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
|
||||
|
||||
get_tflops_func = partial(
|
||||
get_megatron_flops,
|
||||
checkpoint=gpc.config.model.checkpoint,
|
||||
|
@ -359,25 +93,8 @@ def main(args):
|
|||
dist.broadcast_object_list(objs, src=0)
|
||||
current_time = objs[0]
|
||||
|
||||
model_load_path = None
|
||||
if load_resume_ckpt_folder is not None:
|
||||
logger.info(
|
||||
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_resume_ckpt_folder
|
||||
elif load_model_only_folder is not None:
|
||||
logger.info(
|
||||
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_model_only_folder
|
||||
else:
|
||||
logger.info(
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},"
|
||||
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
)
|
||||
# initialize customed llm logger
|
||||
uniscale_logger = initialize_llm_logger(start_time=current_time)
|
||||
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config)
|
||||
|
@ -385,32 +102,66 @@ def main(args):
|
|||
# initialize model
|
||||
model = initialize_model()
|
||||
|
||||
with open(args.config, "r") as f:
|
||||
config_lines = f.readlines()
|
||||
ckpt_manager = CheckpointManager(
|
||||
ckpt_config=gpc.config.ckpt,
|
||||
model=model,
|
||||
model_config=gpc.config.model,
|
||||
model_config_file="".join(config_lines),
|
||||
feishu_address=gpc.config.alert_address,
|
||||
)
|
||||
|
||||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||
|
||||
# initialize the train data loader
|
||||
train_dl, _ = get_train_data_loader(num_worker=4)
|
||||
# initialize the train and validation data loader
|
||||
train_dl, dataset_types = get_train_data_loader(num_worker=4)
|
||||
val_dls = get_validation_data_loader()
|
||||
train_state.init_batch_sampler(train_dl)
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
if model_load_path is not None:
|
||||
load_model_checkpoint(folder=model_load_path, model=model)
|
||||
ckpt_manager.try_load_model(current_time)
|
||||
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||
|
||||
# Loading other persistent training states.
|
||||
if load_resume_ckpt_folder is not None:
|
||||
# load lr scheduler states.
|
||||
load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
|
||||
# load training states.
|
||||
load_context(load_resume_ckpt_folder, train_dl, train_state)
|
||||
# load dataloader sampler states.
|
||||
load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler)
|
||||
# load optimzier states.
|
||||
if load_optimizer:
|
||||
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
|
||||
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
|
||||
|
||||
# initialize customed llm writer
|
||||
writer = Writer(
|
||||
job_name=gpc.config.JOB_NAME,
|
||||
launch_time=current_time,
|
||||
file_name=get_parallel_log_file_name(),
|
||||
tensorboard_folder=gpc.config.tensorboard_folder,
|
||||
resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt.
|
||||
step_count=train_state.step_count, # resume from ckpt.
|
||||
config=config_lines,
|
||||
logger=logger,
|
||||
enable_tb=gpc.config.enable_tb,
|
||||
)
|
||||
|
||||
# initialize metric for calculating accuracy and perplexity
|
||||
metric = AccPerplex(
|
||||
device=torch.cuda.current_device(),
|
||||
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||
dataset_types=dataset_types,
|
||||
)
|
||||
|
||||
# initialize trainer
|
||||
scheduler_hooks = [
|
||||
SchedulerMetricHook(
|
||||
metric=metric,
|
||||
skip=(
|
||||
gpc.is_using_pp()
|
||||
and hasattr(gpc.config.model, "num_chunks")
|
||||
and gpc.config.model.num_chunks > 1
|
||||
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
trainer, train_dl, _, _ = internlm.initialize_trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
|
@ -418,17 +169,17 @@ def main(args):
|
|||
train_dataloader=train_dl,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# initialize simple memory profiler
|
||||
if args.profiling:
|
||||
memory_profiler = SimpleMemoryProfiler(
|
||||
model.model,
|
||||
model,
|
||||
optimizer.optim,
|
||||
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
|
||||
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
|
||||
activation_config=build_activation_config(gpc.config.model.num_layers),
|
||||
)
|
||||
else:
|
||||
memory_profiler = None
|
||||
|
@ -441,89 +192,118 @@ def main(args):
|
|||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
|
||||
# load batch data
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
# load batch data
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
# process data
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
timer("fwd-bwd").stop()
|
||||
assert loss is not None
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=False, return_loss=True, return_output_label=False
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
success_update, grad_norm = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
grad_norm=grad_norm,
|
||||
)
|
||||
success_update, grad_norm_groups = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
message=f"Warning: skip parameter update at step {batch_count}.",
|
||||
)
|
||||
|
||||
timer("one-batch").stop()
|
||||
|
||||
if memory_profiler is not None:
|
||||
memory_profiler.step()
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
|
||||
save_checkpoint(
|
||||
folder=save_ckpt_folder,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
model_config=gpc.config.model,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
grad_norm=np.array(grad_norm_groups),
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
# wait for all checkpoint uploads to be completed
|
||||
dist.barrier()
|
||||
timer("one-batch").stop()
|
||||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
writer=writer,
|
||||
logger=logger,
|
||||
step_count=train_state.step_count,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
now_break = ckpt_manager.try_save_checkpoint(train_state)
|
||||
if now_break:
|
||||
break
|
||||
|
||||
if memory_profiler is not None:
|
||||
memory_profiler.step()
|
||||
|
||||
if batch_count % 2 == 0:
|
||||
prof.step()
|
||||
|
||||
ckpt_manager.wait_async_upload_finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
|
||||
traceback.print_exc()
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# initialize monitor manager context
|
||||
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
|
||||
)
|
||||
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
|
||||
|
|
64
web_demo.py
64
web_demo.py
|
@ -1,23 +1,20 @@
|
|||
"""
|
||||
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers. We mainly modified part of the code logic to adapt to the generation of our model.
|
||||
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
|
||||
We mainly modified part of the code logic to adapt to the generation of our model.
|
||||
Please refer to these links below for more information:
|
||||
1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
|
||||
2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
|
||||
3. transformers: https://github.com/huggingface/transformers
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
import streamlit as st
|
||||
import torch
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Callable, Optional
|
||||
import copy
|
||||
import warnings
|
||||
import logging
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||
|
||||
from tools.transformers.interface import generate_interactive, GenerationConfig
|
||||
from tools.transformers.interface import GenerationConfig, generate_interactive
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -25,9 +22,14 @@ logger = logging.get_logger(__name__)
|
|||
def on_btn_click():
|
||||
del st.session_state.messages
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def load_model():
|
||||
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
|
||||
.to(torch.bfloat16)
|
||||
.cuda()
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
@ -35,20 +37,12 @@ def load_model():
|
|||
def prepare_generation_config():
|
||||
with st.sidebar:
|
||||
max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
|
||||
top_p = st.slider(
|
||||
'Top P', 0.0, 1.0, 0.8, step=0.01
|
||||
)
|
||||
temperature = st.slider(
|
||||
'Temperature', 0.0, 1.0, 0.7, step=0.01
|
||||
)
|
||||
top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
|
||||
temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
|
||||
st.button("Clear Chat History", on_click=on_btn_click)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_length=max_length,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
|
||||
generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
|
||||
|
||||
return generation_config
|
||||
|
||||
|
||||
|
@ -74,16 +68,16 @@ def combine_history(prompt):
|
|||
|
||||
|
||||
def main():
|
||||
#torch.cuda.empty_cache()
|
||||
# torch.cuda.empty_cache()
|
||||
print("load model begin.")
|
||||
model, tokenizer = load_model()
|
||||
print("load model end.")
|
||||
|
||||
|
||||
user_avator = "doc/imgs/user.png"
|
||||
robot_avator = "doc/imgs/robot.png"
|
||||
|
||||
|
||||
st.title("InternLM-Chat-7B")
|
||||
|
||||
|
||||
generation_config = prepare_generation_config()
|
||||
|
||||
# Initialize chat history
|
||||
|
@ -106,22 +100,20 @@ def main():
|
|||
|
||||
with st.chat_message("robot", avatar=robot_avator):
|
||||
message_placeholder = st.empty()
|
||||
for cur_response in generate_interactive(model=model, tokenizer=tokenizer, prompt=real_prompt, additional_eos_token_id=103028, **asdict(generation_config)):
|
||||
for cur_response in generate_interactive(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=real_prompt,
|
||||
additional_eos_token_id=103028,
|
||||
**asdict(generation_config),
|
||||
):
|
||||
# Display robot response in chat message container
|
||||
message_placeholder.markdown(cur_response + "▌")
|
||||
message_placeholder.markdown(cur_response)
|
||||
# Add robot response to chat history
|
||||
st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue