Merge develop to main (#233)

* feat(utils/writer.py): support tensorboard writer (#63)

* feat(utils/writer.py): support tensorboard writer

* feat(utils/writer.py): add class comment

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* [Develop] Pull Main Branch (#121)

* fix/fix_submodule_err (#61)

* fix/fix_submodule_err

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* fix issue templates (#65)

* fix(tokenizer): refactor tokenizer and update usage in readme (#51)

* update tokenizer example

* fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73)

* fix a typo in readme

* in order to find InternLMTokenizer, select a lower version of Transformers

---------

Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>

* [Doc] Add wechat and discord link in readme (#78)

* Doc:add wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* [Docs]: add Japanese README (#43)

* Add Japanese README

* Update README-ja-JP.md

replace message

* Update README-ja-JP.md

* add repetition_penalty in GenerationConfig in web_demo.py (#48)

Co-authored-by: YWMditto <862779238@qq.com>

* use fp16 in instruction (#80)

* [Enhancement] add more options for issue template (#77)

* [Enhancement] add more options for issue template

* update question icon

* fix link

* Use tempfile for convert2hf.py (#23)

Fix https://github.com/InternLM/InternLM/issues/50

* delete torch_dtype of README's example code (#100)

* set the value of repetition_penalty to 1.0 to avoid random outputs (#99)

* Update web_demo.py (#97)

Remove meaningless log.

* [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106)

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* feat(core/scheduler): support pipeline parallel (#98)

* feat(utils/writer.py): support tensorboard writer

* feat(utils/writer.py): add class comment

* feat(core): support pipeline parallel

* fix(core): fix demo running error

* feat(solver/optimizer): add pp zero optimizer

* fix(solver/optimizer): fix word spelling error

* feat(core/scheduler): add new dir scheduler in core/

* fix(core): fix ci lint error

* feat(solver/optimizer): merge pp and nopp optimizer

* doc(usage.md): update usage doc

* feat(core/scheduler): support post func

* feat(core/scheduler): add dtype para in pp sche and update func get_tensor_shape

* feat(core/scheduler): add _load_micro_batch in base scheduler

* feat(core/scheduler): support optimizer overlap communication in pp scheduler

* feat(core/scheduler): delete data process func code

* feat(core/trainer): schedule pre processing for all schedule

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>

* refactor(rotaryEmbedding): refactor forward (#120)

* use fp16 in instruction (#80)

* delete torch_dtype of README's example code (#100)

* refactor the forward for rotary embedding

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>

* feat(model/metrics.py): support calculating accuracy and perplexity m… (#91)

* feat(model/metrics.py): support calculating accuracy and perplexity metrics

* fix(model/metrics.py): fix import error

* feat(train.py): minor update

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>

* fix(optimizer/util.py): change inf definition

* [Dev] Pull Main (#139)

* fix/fix_submodule_err (#61)

* fix/fix_submodule_err

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* fix issue templates (#65)

* fix(tokenizer): refactor tokenizer and update usage in readme (#51)

* update tokenizer example

* fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73)

* fix a typo in readme

* in order to find InternLMTokenizer, select a lower version of Transformers

---------

Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>

* [Doc] Add wechat and discord link in readme (#78)

* Doc:add wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* [Docs]: add Japanese README (#43)

* Add Japanese README

* Update README-ja-JP.md

replace message

* Update README-ja-JP.md

* add repetition_penalty in GenerationConfig in web_demo.py (#48)

Co-authored-by: YWMditto <862779238@qq.com>

* use fp16 in instruction (#80)

* [Enhancement] add more options for issue template (#77)

* [Enhancement] add more options for issue template

* update question icon

* fix link

* Use tempfile for convert2hf.py (#23)

Fix https://github.com/InternLM/InternLM/issues/50

* delete torch_dtype of README's example code (#100)

* set the value of repetition_penalty to 1.0 to avoid random outputs (#99)

* Update web_demo.py (#97)

Remove meaningless log.

* [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106)

* docs(install.md): update dependency package transformers version to >= 4.28.0 (#124)

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* docs(LICENSE): add license (#125)

* add license of colossalai and flash-attn

* fix lint

* modify the name

* fix AutoModel map in convert2hf.py (#116)

* variables are not printed as expected (#114)

* feat(solver): fix code to adapt to torch2.0 and provide docker images (#128)

* feat(solver): fix code to adapt to torch2.0

* docs(install.md): publish internlm environment image

* docs(install.md): update dependency packages version

* docs(install.md): update default image

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* add demo test (#132)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* fix web_demo cache accelerate (#133)

* fix(hybrid_zero_optim.py): delete math import

* Update embedding.py

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: huangting4201 <1538303371@qq.com>
Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>
Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com>

* style(solver/optimizer/utils.py): fix lint error (#147)

Co-authored-by: huangting.p <huangting@sensetime.com>

* feat(*): support not-flash-attn for pp and no-pp (#145)

* support not flash attention for no-pp

* support pipeline

* modify the config

* refactor the code

* refactor the code

* remove some unnecessary code

* fix(initialize/launch.py): set default value for use_flash_attn (#158)

* add default for use_flash_attn

* fix lint

* feat(utils/logger.py): support uniscale logger (#152)

* style(internlm): fix lint error

* feat(utils/logger.py): support uniscale logger

* fix(utils/logger.py): fix import circular error

* feat(train.py): support dashboard metric panel and fix ci train config

* fix(ci_scripts/train/slurm_train.sh): fix ci train error

* fix(ci_scripts/train/torchrun.sh): fix ci train error

* fix(ci_scripts/train): restore ci update

* fix(config.json): delete alert webhook

* feat(train.py): optimize func init logger

* feat(config.json): delete config.json

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>

* feat(utils/evaluation.py): support evaluate (#154)

* style(internlm): fix lint error

* feat(utils/logger.py): support uniscale logger

* fix(utils/logger.py): fix import circular error

* feat(train.py): support dashboard metric panel and fix ci train config

* fix(ci_scripts/train/slurm_train.sh): fix ci train error

* fix(ci_scripts/train/torchrun.sh): fix ci train error

* feat(utils/evaluation.py): support evaluate on validation dataset

* fix(utils/evaluation.py): fix demo error

* fix(ci_scripts/train/ci_7B_sft.py): fix ci train error

* feat(initialize/launch.py): set default value for valid_bsz and valid_every

* fix(ci_scripts/train): restore ci update

* docs(configs/7B_sft.py): update comment for config

* fix(config.json): delete config.json

* fix evaluation bug in scheduler when use_flash_attn=False

* feat(scheduler/no_pipeline_scheduler.py): support micro_bsz>1 in no pp

* modify the judgement in pp and no-pp scheduler

* modify the data_process_func in evaluation

* fix bugs when use_flash_attn=False

* rename symbol

* feat(configs/7B_sft.py): change para valid_bsz to valid_micro_num

* feat(scheduler/no_pipeline_scheduler.py): update para set _grad_accum_batch_size

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>
Co-authored-by: yingtongxiong <974106207@qq.com>

* feat(*): support no apex (#166)

* support no-apex

* add default for use_apex

* fix lint

* modify the RMSNormTorch

* remove some comments

* remove use_apex parameter

* remove some unnecessary code

* refactor(*): refactor the code with no-apex (#170)

* support no-apex

* add default for use_apex

* fix lint

* modify the RMSNormTorch

* remove some comments

* remove use_apex parameter

* remove some unnecessary code

* optimize the code including import

* remove the import RMSNorm

* remove warnings

* refactor(scheduler): rewrite pipeline scheduler (#138)

* refactor(scheduler): rewrite pipeline scheduler

* fix(*): fix pipeline scheduler bugs

* fix(*): fix merge bug

* feat(*): update codes with todo tag

* feat(*): add comments

* feat(internlm/core/scheduler): update recv_prev/next logic

* feat(utils/evaluation.py): update sche metric hook for valid

---------

Co-authored-by: huangting.p <huangting@sensetime.com>

* feat(*): support fp32 training (#155)

* support float32 training

* fix lint

* add adaptation in model/utils.py

* remove some unnecessary code

* fix lint

* feat(optim): add support for fp32 zero

* Revert "Merge pull request #2 from SolenoidWGT/fp32_zero"

This reverts commit 53fc50b0e5, reversing
changes made to 40f24d0a73.

revert commit

* merge develop

* Update utils.py

* support fp32 in zero optimizer

* modify the dtype

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* feat(*): support sequence_parallel (#180)

* support sequence_parallel for no pipeline

* sequence_parallel does not support no-flash-attn

* support sequence parallel for pipeline

* add memory profiler

* Update 13B.py

* add memory profiler

* fix evaluation bug

* remove some unnecessary code

* remove some unnecessary code

* Update parallel_context.py

* modify the config

* remove memory profiler

* modify the config

* support selective dropout

* feat(monitor): support monitor and alert (#175)

* feat(monitor): support monitor and alert

* feat(monitor.py): fix demo error

* feat(monitor.py): move cmd monitor args to config file

* feat(hybrid_zero_optim.py): if overflow occurs send alert msg

* feat(monitor.py): remove alert msg filter

* feat(monitor.py): optimize class MonitorTracker

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(monitor.py): optimize code

* feat(train.py): update print to log

* style(ci): fix lint error

* fix(utils/evaluation.py): remove useless code

* fix(model/modeling_internlm.py): fix lint error

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>

* feat(ckpt): add async upload and ckpt snapshot (#161)

* use fp16 in instruction (#80)

* delete torch_dtype of README's example code (#100)

* feat(ckpt): support async ckpt upload and ckpt snapshot

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* feat(ckpt): add auto ckpt load and signal quit (#189)

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* Revert "feat(ckpt): add auto ckpt load and singal quit (#189)" (#192)

This reverts commit a45a91bb84.

* refactor(solver/optimizer): improve optimizer memory (#193)

* refactor(solver/optimizer): improve optimizer memory

* feat(data): remove useless dataset type ids map

* Feat/optimizer (#194)

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call

* feat(optimizer.py): overlap compute norm with allreduce

* update var and function name

* update function compute norm (#197)

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* feat(optimizer/hybrid_zero_optim.py): overlap gradients last bucket allreduce and compute norm (#196)

* support gradients allreduce and compute norm overlap

* fix para set error

* remove timer cal_norm for testing

* feat(optimizer/hybrid_zero_optim.py): support group global norm

* format(lint): fix lint error

* feat(optimizer/store.py): update code based on comment

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: huangting4201 <1538303371@qq.com>

* fix(ci): fix ci train error (#199)

* fix/ci train error (#200)

* fix(ci): fix ci train error

* fix(ci): fix ci train error

* fix(ci): fix ci train error

* fix(train.py): fix scheduler metric hook skip error (#204)

* Merge main to develop (#203)

* fix/fix_submodule_err (#61)

* fix/fix_submodule_err

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* fix issue templates (#65)

* fix(tokenizer): refactor tokenizer and update usage in readme (#51)

* update tokenizer example

* fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73)

* fix a typo in readme

* in order to find InternLMTokenizer, select a lower version of Transformers

---------

Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>

* [Doc] Add wechat and discord link in readme (#78)

* Doc:add wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* [Docs]: add Japanese README (#43)

* Add Japanese README

* Update README-ja-JP.md

replace message

* Update README-ja-JP.md

* add repetition_penalty in GenerationConfig in web_demo.py (#48)

Co-authored-by: YWMditto <862779238@qq.com>

* use fp16 in instruction (#80)

* [Enhancement] add more options for issue template (#77)

* [Enhancement] add more options for issue template

* update question icon

* fix link

* Use tempfile for convert2hf.py (#23)

Fix https://github.com/InternLM/InternLM/issues/50

* delete torch_dtype of README's example code (#100)

* set the value of repetition_penalty to 1.0 to avoid random outputs (#99)

* Update web_demo.py (#97)

Remove meaningless log.

* [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106)

* docs(install.md): update dependency package transformers version to >= 4.28.0 (#124)

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* docs(LICENSE): add license (#125)

* add license of colossalai and flash-attn

* fix lint

* modify the name

* fix AutoModel map in convert2hf.py (#116)

* variables are not printed as expected (#114)

* feat(solver): fix code to adapt to torch2.0 and provide docker images (#128)

* feat(solver): fix code to adapt to torch2.0

* docs(install.md): publish internlm environment image

* docs(install.md): update dependency packages version

* docs(install.md): update default image

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* add demo test (#132)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* fix web_demo cache accelerate (#133)

* Doc: add twitter link (#141)

* Feat add checkpoint fraction (#151)

* feat(config): add checkpoint_fraction into config

* feat: remove checkpoint_fraction from configs/7B_sft.py

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* [Doc] update deployment guide to keep consistency with lmdeploy (#136)

* update deployment guide

* fix error

* use llm partition (#159)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* test(ci_scripts): clean test data after test, remove unnecessary global variables, and other optimizations (#165)

* test: optimization of ci scripts (variables, test data cleaning, etc.).

* chore(workflows): disable ci job on push.

* fix: update partition

* test(ci_scripts): install requirements automatically, add lint check trigger events, and other optimizations (#174)

* add pull_request in lint check

* use default variables in ci_scripts

* fix format

* check and install requirements automatically

* fix format

---------

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* feat(profiling): add a simple memory profiler (#89)

* feat(profiling): add simple memory profiler

* feat(profiling): add profiling argument

* feat(CI_workflow): Add PR & Issue auto remove workflow (#184)

* feat(ci_workflow): Add PR & Issue auto remove workflow

Add a workflow for stale PR & Issue auto removal:
- PRs & issues will be labeled as stale after 7 days of inactivity
- stale PRs & issues will be removed after 7 days
- run this workflow every day at 1:30 a.m.

* Update stale.yml

* feat(bot): Create .owners.yml for Auto Assign (#176)

* Create .owners.yml: for issue/pr assign automatically

* Update .owners.yml

* Update .owners.yml

fix typo

* [feat]: add pal reasoning script (#163)

* [Feat] Add PAL inference script

* Update README.md

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update tools/pal_inference.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update pal script

* Update README.md

* restore .pre-commit-config.yaml

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update pal inference script

* Update README.md

* Update internlm/utils/interface.py

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>

* Update pal script

* Update pal script

* Update script

* Add docstring

* Update format

* Update script

* Update script

* Update script

---------

Co-authored-by: BigDong <yudongwang1226@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>

* test(ci_scripts): add timeout settings and clean work after the slurm job (#185)

* restore pr test on develop branch

* add mask

* add post action to cancel slurm job

* remove readonly attribute on job log

* add debug info

* debug job log

* try stdin

* use stdin

* set default value to avoid error

* try setting readonly on job log

* performance echo

* remove debug info

* use squeue to check slurm job status

* restore the lost param

* limit retry times

* use exclusive to avoid port already in use

* optimize loop body

* remove partition

* add {} for variables

* set env variable for slurm partition

---------

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* refactor(tools): move interface.py and import it to web_demo (#195)

* move interface.py and import it to web_demo

* typo

* fix(ci): fix lint error

* fix(ci): fix lint error

---------

Co-authored-by: Sun Peng <sunpengsdu@gmail.com>
Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>
Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com>
Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com>
Co-authored-by: cx <759046501@qq.com>
Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com>
Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com>
Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com>
Co-authored-by: BigDong <yudongwang1226@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: huangting4201 <huangting3@sensetime.com>

* fix(pipeline_scheduler.py): fix tensor shape err and comm block (#210)

* feat(train.py): support torch profiler (#201)

* feat(train.py): support torch profiling

* feat(train.py): optimize initialize_llm_profile

* feat(train.py): profiling with tp0 and dp0

* move sequence parallel context manager to evaluation func

* fix lint

* move the process for type_ids to load_new_batch

* fix lint

---------

Co-authored-by: yingtongxiong <974106207@qq.com>

* feat(ckpt): add auto ckpt load and signal quit (#216)

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* feat(memory_profiler): improve memory profiler (#217)

* Feat/overlap_bcast_forward (#218)

* feat/support bcast forward overlap

* feat/optimize the bcast call

* feat/optimize the bcast call

* feat/optimize the bcast call

* fix lint

* fix lint

* fix lint

* fix lint

* add torch.cuda.synchronize in save_checkpoint

---------

Co-authored-by: sunpeng <sunpengsdu@gmail.com>

* fix(*): move sequence_parallel to parallel config (#224)

* move sequence_parallel to parallel config

* set the sequence_parallel default value to False

* fix lint

* fix lint

* fix lint

* Feat/example training internlm (#212)

* feat(train/training_internlm.py): move common init funcs to internlm/train

* feat(train/training_internlm.py): update some public funcs

* feat(train/training_internlm.py): update some public funcs

* feat(evaluation.py): adapt evaluate to streaming dataset

* feat(train/training_internlm.py): minor update based on comments

* fix(training_internlm.py): set train dataloader persistent_workers true only when num_worker>0

* fix(training_internlm.py): fix demo error

* feat(data/utils.py): add new dataset type code for streaming dataset (#225)

* test(model): support fp32 with flash_attn (#223)

* support tf32 with flash

* move autocast to attention

* fix lint

* fix lint

* fix lint

* fix lint

* fix some bugs in model

* modify the convert dtype

* fix(pipeline): modify the sequence_parallel in pipeline (#227)

* move sequence_parallel to parallel config

* set the sequence_parallel default value to False

* fix lint

* fix lint

* fix lint

* modify the sequence_parallel in pp

* feat(init): add skip args check flag and add zero overlap flag (#222)

* feat(init): add skip args check flag

* fix(optim): add param overlap enable flag

* fix(ci): fix train error (#228)

Co-authored-by: huangting4201 <huangting3@sensetime.com>

* fix(writer): fix tensorboard resume bug (#229)

* fix(train.py): fix overflow grad norm error (#230)

* feat(ckpt): add train config into ckpt (#231)

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: Sun Peng <sunpengsdu@gmail.com>
Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: huangting.p <huangting@sensetime.com>
Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>
Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com>
Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: cx <759046501@qq.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
Co-authored-by: huangting4201 <huangting3@sensetime.com>
Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com>
Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com>
Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com>
Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com>
Co-authored-by: BigDong <yudongwang1226@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Branches/tags: pull/238/head, v0.2.0
Authored by huangting4201 on 2023-08-24 22:03:04 +08:00, committed via GitHub
Commit 54f85a6e9a (parent e1cefaef6b)
59 changed files with 6108 additions and 1315 deletions


@@ -1,5 +1,5 @@
name: demo-in-readme
on:
on:
pull_request:
branches:
- "main"
@@ -110,7 +110,6 @@ jobs:
srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --gpus-per-task=2 python ../ci_scripts/model/loaded_as_transformer.py
cd ..
rm -rf $GITHUB_WORKSPACE/hf_ckpt
load-chat-model-in-hf:
if: ${{ always() }}
needs: check-requirements

.gitignore (4 changes)

@@ -115,6 +115,7 @@ venv.bak/
*.pkl
*.pkl.json
*.log.json
*.trace.json
docs/modelzoo_statistics.md
mmdet/.mim
work_dirs/
@@ -142,4 +143,5 @@ core.*
# Run
llm_ckpts
memory_trace
events.*
memory_trace


@@ -49,5 +49,5 @@ repos:
args:
[
'--rcfile=.pylintrc',
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703'
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
]


@@ -15,6 +15,7 @@ VOCAB_SIZE = 103168
SAVE_CKPT_FOLDER = "local:llm_ckpts"
# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
ckpt = dict(
enable_save_ckpt=True,
# Path to save training ckpt.
save_ckpt_folder=SAVE_CKPT_FOLDER,
# Path to continue training ckpt (load model weights and scheduler/context states).


@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
expected_num=21
expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh


@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
expected_num=21
expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh


@@ -5,7 +5,7 @@ set -x
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
expected_num=21
expected_num=22
exit_code=0
source ./ci_scripts/common/basic_func.sh


@@ -7,31 +7,43 @@ MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
# oss: 'boto3:s3://model_weights/XXX'
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
# Path to save training ckpt.
save_ckpt_folder=SAVE_CKPT_FOLDER,
# Path to continue training ckpt (load model weights and scheduler/context states).
# load_ckpt_folder=LOAD_CKPT_FOLDER,
# Path to initialize with given model weights.
# load_model_only_folder=MODEL_ONLY_FOLDER,
checkpoint_every=50,
# Whether to load optimizer states when continuing training.
load_optimizer=True,
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training (load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
@@ -39,6 +51,7 @@ data = dict(
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
)
grad_scaler = dict(
@@ -62,7 +75,8 @@ grad_scaler = dict(
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
zero_overlap_communication=True,
overlap_sync_grad=True,
overlap_sync_param=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
@@ -107,9 +121,11 @@ model = dict(
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
zero1 parallel:
@@ -118,11 +134,15 @@ zero1 parallel:
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel: pipeline parallel size, only 1 is accepted currently.
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=8,
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=False,
)
cudnn_deterministic = False


@@ -174,7 +174,7 @@ parallel = dict(
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
- pipeline: pipeline parallel size, currently only supports 1, default value is 1
- pipeline: pipeline parallel size, default value is 1
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
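
To make the parallel settings described above concrete, here is a small illustrative sketch (not part of this diff) of a parallel config for a hypothetical 32-GPU job; the `tensor` field uses its documented default of 1, and `sequence_parallel` mirrors the new flag shown in the config diff above:

parallel = dict(
    zero1=8,  # optimizer states sharded within groups of 8 ranks of the data parallel group
    pipeline=dict(size=1, interleaved_overlap=True),  # single pipeline stage
    tensor=1,  # no tensor parallelism (documented default)
    sequence_parallel=False,
)
# With 32 GPUs: data parallel size = 32 / 1 (pipeline) / 1 (tensor) = 32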


@@ -159,7 +159,7 @@ parallel = dict(
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group
- pipeline: pipeline parallel size, currently only supports 1, default value is 1
- pipeline: pipeline parallel size, default value is 1
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`


@@ -0,0 +1,32 @@
from .p2p import (
AsynCommunicator,
recv_backward,
recv_forward,
send_backward,
send_backward_and_recv_next_backward_async,
send_backward_recv_backward,
send_backward_recv_forward,
send_forward,
send_forward_and_recv_next_forward_async,
send_forward_backward_recv_forward_backward,
send_forward_recv_backward,
send_forward_recv_forward,
)
from .utils import recv_obj_meta, send_obj_meta
__all__ = [
"send_forward",
"send_forward_recv_forward",
"send_forward_backward_recv_forward_backward",
"send_backward",
"send_backward_recv_backward",
"send_backward_recv_forward",
"send_forward_recv_backward",
"recv_backward",
"recv_forward",
"send_obj_meta",
"recv_obj_meta",
"send_backward_and_recv_next_backward_async",
"send_forward_and_recv_next_forward_async",
"AsynCommunicator",
]


@@ -0,0 +1,582 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
import operator
from functools import reduce
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
Args:
tensor_shape (:class:`torch.Size`): shape of tensor
chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
Returns:
Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
"""
if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
if tensor_chunk_shape % tensor_parallel_world_size == 0:
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
else:
tensor_chunk_shape = tensor_shape
chunk_tensor = False
else:
tensor_chunk_shape = tensor_shape
return tensor_chunk_shape, chunk_tensor
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
return buffer_recv, recv_split
buffer_recv = []
for recv_shape in recv_shapes:
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
buffer_recv.append(tensor_recv)
return buffer_recv, recv_split
def process_object_to_send(object_send, scatter_gather_tensors):
if isinstance(object_send, torch.Tensor):
send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
if send_split:
object_send = split_tensor_into_1d_equal_chunks(object_send)
return object_send
object_send_list = []
for tensor_send in object_send:
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
if send_split:
object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
else:
object_send_list.append(tensor_send)
object_send = tuple(object_send_list)
return object_send
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
if isinstance(obj, torch.Tensor):
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
ops_queue.append(op_to_add)
else:
for tensor_to_comm in obj:
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
ops_queue.append(op_to_add)
def _communicate(
object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule.
Takes the following arguments:
object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
(no tensor sent if set to None).
object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
(no tensor sent if set to None).
recv_prev (bool): boolean for whether tensor should be received from
previous rank.
recv_next (bool): boolean for whether tensor should be received from
next rank.
recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
from the previous stage, defaults to None.
recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
from the next stage, defaults to None.
prev_rank (int): the rank of the previous pipeline stage, defaults to None.
next_rank (int): the rank of the next pipeline stage, defaults to None.
dtype (torch.dtype): data type of intermediate buffers, defaults to None
scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
Returns:
Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
"""
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
recv_prev_shape, dtype, scatter_gather_tensors
)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
recv_next_shape, dtype, scatter_gather_tensors
)
if object_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
if object_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
if object_send_prev is not None:
object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
if object_send_next is not None:
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
ops = []
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = (
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
)
if recv_next and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = (
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
)
return tensor_recv_prev, tensor_recv_next
def recv_forward(
input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
"""
input_tensor, _ = _communicate(
recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def recv_backward(
output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
"""
_, output_tensor_grad = _communicate(
recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the input tensor to the next stage in pipeline.
Args:
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
prev_rank (int, optional): The rank of the recipient of the tensor
"""
_communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_forward_recv_backward(
output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage.
Args:
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
_, output_tensor_grad = _communicate(
object_send_next=output_tensor,
recv_next=output_grad_shape is not None,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad,
input_tensor_shape,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
Args:
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
input_tensor, _ = _communicate(
object_send_prev=input_tensor_grad,
recv_prev=input_tensor_shape is not None,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def send_forward_recv_forward(
output_tensor,
input_tensor_shape,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
Args:
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
input_tensor, _ = _communicate(
object_send_next=output_tensor,
recv_prev=input_tensor_shape is not None,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def send_backward_recv_backward(
input_tensor_grad,
output_grad_shape,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage.
Args:
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
to be received.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
_, output_tensor_grad = _communicate(
object_send_prev=input_tensor_grad,
recv_next=output_grad_shape is not None,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage.
Args:
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
from the previous.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
from the next.
Returns:
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
"""
input_tensor, output_tensor_grad = _communicate(
object_send_next=output_tensor,
object_send_prev=input_tensor_grad,
recv_prev=input_tensor_shape is not None,
recv_next=output_grad_shape is not None,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor, output_tensor_grad
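# Illustrative note, not part of the original file: in a 1F1B-style steady state, a
# middle pipeline stage typically pairs the helpers above as recv_forward() ->
# local forward -> send_forward_recv_backward() -> local backward ->
# send_backward_recv_forward(), so each send is batched with the matching receive
# in a single p2p call instead of two blocking ones.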
def send_forward_and_recv_next_forward_async(
output_tensor,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
dtype: torch.dtype = None,
scatter_gather_tensors=False,
):
"""send forward output to next rank and recv forward input from prev rank"""
reqs = []
tensor_recv_prev = None
# prepare send operations
if output_tensor is not None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
if isinstance(output_tensor, torch.Tensor):
reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
else:
for tensor_to_comm in output_tensor:
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
# prepare receive operations
if recv_prev_shape is not None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
# create receive buffer
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
recv_prev_shape, dtype, scatter_gather_tensors
)
# generate async receive operations
if isinstance(tensor_recv_prev, torch.Tensor):
reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
else:
for tensor_to_comm in tensor_recv_prev:
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
# return and do other things
yield
# check communication completed
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv()
torch.cuda.synchronize()
# Process received data
if recv_prev_shape is not None and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = (
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
)
yield tensor_recv_prev
def send_backward_and_recv_next_backward_async(
input_tensor,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
dtype: torch.dtype = None,
scatter_gather_tensors=False,
):
reqs = []
tensor_recv_next = None
# prepare send operations
if input_tensor is not None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
if isinstance(input_tensor, torch.Tensor):
reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
else:
for tensor_to_comm in input_tensor:
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
# prepare receive operations
if recv_next_shape is not None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
# create receive buffer
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
recv_next_shape, dtype, scatter_gather_tensors
)
# generate async receive operations
if isinstance(tensor_recv_next, torch.Tensor):
reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
else:
for tensor_to_comm in tensor_recv_next:
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
# return and do other things
yield
# check communication completed
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv()
torch.cuda.synchronize()
# Process received data
if recv_next_shape is not None and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = (
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
)
yield tensor_recv_next
class AsynCommunicator:
"""AsynCommunicator for managing async communication."""
def __init__(
self,
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
recv_shape: Union[torch.Size, List[torch.Size]],
dtype: torch.dtype = None,
scatter_gather_tensors=False,
forward: bool = True,
) -> None:
self._need_receive = recv_shape is not None
if forward:
self._coroutine = send_forward_and_recv_next_forward_async(
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
)
else:
self._coroutine = send_backward_and_recv_next_backward_async(
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
)
@property
def need_receive(self) -> bool:
return self._need_receive
def start(self) -> None:
next(self._coroutine)
def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
received = next(self._coroutine)
self._coroutine.close()
return received
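# Illustrative driving pattern, not part of the original file: the pipeline
# scheduler would construct an AsynCommunicator with the tensor to send and the
# expected receive shape, call start() to post the batched isend/irecv requests,
# overlap local computation with the in-flight communication, and finally call
# wait_and_receive() to wait on those requests and obtain the received tensor.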


@@ -0,0 +1,125 @@
# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def send_meta_helper(obj, next_rank, tensor_kwargs):
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)
def send_obj_meta(obj, next_rank=None):
"""Sends obj meta information before sending a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be sent before communications. This function
synchronizes with :func:`recv_obj_meta`.
Args:
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
next_rank (int): The rank of the next member in pipeline parallel group.
    Returns:
        None
"""
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
send_meta_helper(obj, next_rank, tensor_kwargs)
else:
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
for tensor_to_send in obj:
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
def recv_meta_helper(prev_rank, tensor_kwargs):
recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)
return recv_shape
def recv_obj_meta(prev_rank=None) -> torch.Size:
"""Receives obj meta information before receiving a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be received before communications. This function
synchronizes with :func:`send_obj_meta`.
Args:
prev_rank (int): The rank of the source of the obj.
Returns:
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
"""
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape = torch.Size(recv_shape)
else:
obj_shape = []
for _ in range(recv_obj_nums.item()):
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape.append(torch.Size(recv_shape))
return obj_shape
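
# --- Hedged usage sketch (editor addition, not part of the diff) ---
# The meta exchange is a two-sided protocol: the sender calls send_obj_meta before
# dist.send of the payload, and the receiver calls recv_obj_meta to learn the shape
# before allocating its receive buffer. The ranks, tensor and dtype here are
# assumptions; dtype must be agreed upon out of band.
def _meta_exchange_example(hidden_states, is_sender: bool, dtype=torch.float32):
    if is_sender:
        send_obj_meta(hidden_states)  # ndims and shape go over the wire first
        dist.send(hidden_states, gpc.get_next_global_rank(ParallelMode.PIPELINE))
        return None
    shape = recv_obj_meta()  # torch.Size, or a list of them for multi-tensor objects
    buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
    dist.recv(buffer, gpc.get_prev_global_rank(ParallelMode.PIPELINE))
    return buffer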
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
"""Break a tensor into equal 1D chunks.
Args:
tensor (:class:`torch.Tensor`): Tensor to be split before communication.
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
Returns:
:class:`torch.Tensor`: The split tensor
"""
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Opposite of above function, gather values from model parallel ranks.
Args:
tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
Returns:
:class:`torch.Tensor`: The gathered tensor.
"""
world_size = gpc.get_world_size(ParallelMode.TENSOR)
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
return gathered
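
# --- Hedged, standalone sketch (editor addition) ---
# Mirrors the split/gather round trip above with the tensor-parallel collective
# replaced by plain indexing; world_size here is an assumption, whereas the real
# functions read it from gpc.get_world_size(ParallelMode.TENSOR).
def _split_gather_roundtrip_demo() -> torch.Tensor:
    world_size = 4
    full = torch.arange(16, dtype=torch.float32)  # numel must be divisible by world_size
    partition_size = full.numel() // world_size
    # chunks each rank would keep after split_tensor_into_1d_equal_chunks
    chunks = [full.view(-1)[r * partition_size : (r + 1) * partition_size] for r in range(world_size)]
    # gather_split_1d_tensor reconstructs the flat tensor once all chunks are all-gathered
    gathered = torch.cat(chunks)
    assert torch.equal(gathered, full)
    return gathered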

View File

@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
for initializer in initializers:
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):

View File

@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
input_ = input_.float()
return input_
def convert_to_fp32(self, out):
"""Converts the output to fp32"""
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out
def _reduce_module_buffer(self):
"""
All-reduces the buffers (e.g., running stats of batch normalization) across
@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
out = self.model(*args, **kwargs)
if self._output_to_fp32:
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
out = self.convert_to_fp32(out)
return out

View File

@ -1,279 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable
import torch
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
class BaseScheduler(ABC):
"""A basic helper class to control the process of training or evaluation.
    It mainly consists of forward_backward_step, which runs the forward and backward pass,
    and optimizer_step, which updates the parameters.
    For convenience when enabling FP16, all FP16-related control code is aggregated in the
    schedule classes.
Args:
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
them into data and label.
"""
def __init__(self, data_process_func: Callable = None):
self.data_process_func = data_process_func
@abstractmethod
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
"""
pass
@abstractmethod
def forward_backward_step(
self,
engine: Engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function over a batch of dataset for training or evaluation.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
forward_only (bool): If True, the process won't include backward.
return_loss (bool, optional): If False, the loss won't be returned.
return_output_label (bool, optional): If False, the output and label won't be returned.
"""
pass
@staticmethod
def _call_engine(engine: Engine, inputs: Any):
"""Calls the engine with the given inputs.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
"""
if isinstance(inputs, torch.Tensor):
return engine(inputs)
elif isinstance(inputs, (list, tuple)):
return engine(*inputs)
elif isinstance(inputs, dict):
return engine(**inputs)
else:
raise TypeError(
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
)
@staticmethod
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
"""Calls the engine's criterion with the given outputs and labels.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
"""
assert isinstance(
outputs, (torch.Tensor, list, tuple, dict)
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):
labels = (labels,)
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
return engine.criterion(*outputs, *labels)
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
return engine.criterion(*outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, dict):
return engine.criterion(**outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
else:
raise TypeError(
f"Expected model outputs and labels to be of type torch.Tensor ' \
'(which is auto-converted to tuple), list, tuple, or dict, ' \
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
)
class NonPipelineScheduler(BaseScheduler):
"""A helper schedule class for no pipeline parallelism running environment.
    During one process, it loads a batch of data and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
Args:
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
and returns a tuple in the form of (data, label), and it will be executed in load_batch.
        gradient_accumulation_size (int, optional): the number of gradient accumulation steps; set to 1
            to disable gradient accumulation.
Example:
# this shows an example of customized data_process_func
def data_process_func(dataloader_output):
item1, item2, item3 = dataloader_output
data = (item1, item2)
label = item3
return data, label
"""
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
# check that non-pipeline schedule data process func only takes in one parameter
# which is the batch data
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 1, (
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
"which is a tuple of tensors for the current batch, "
"i.e. data_process_func(dataloader_output)."
)
self._grad_accum_size = gradient_accumulation_size
        self._grad_accum_batch_size = 1  # static batch size for flash attention.
self._grad_accum_offset = 0
super().__init__(data_process_func)
def pre_processing(self, engine: Engine):
"""Performs actions before running the schedule.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
"""
pass
def _load_accum_batch(self, data: Any, label: Any):
"""Loads a batch of data and label for gradient accumulation.
Args:
data (Any): The data to be loaded.
label (Any): The label to be loaded.
"""
_data = {
k: v[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
for k, v in data.items()
}
_label = label[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
self._grad_accum_offset += self._grad_accum_batch_size
return _data, _label
def _train_one_batch(
self,
data: Any,
label: Any,
engine: Engine,
forward_only: bool = False,
return_loss: bool = True,
scale_loss: int = 1,
):
"""Trains one batch of data.
Args:
data (Any): The data to be trained.
label (Any): The label for the data.
engine (internlm.core.Engine): InternLM engine for training and inference.
forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
be executed.
return_loss (bool, optional): Loss will be returned if True.
scale_loss (int, optional): The scale factor for the loss.
"""
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
output = self._call_engine(engine, data)
if return_loss:
loss = self._call_engine_criterion(engine, output, label)
loss /= scale_loss
# backward
if not forward_only:
engine.backward(loss)
if not return_loss:
loss = None
return output, loss
def forward_backward_step(
self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
If True, the model is run for the forward pass, else back propagation will be executed.
return_loss (bool, optional): Loss will be returned if True.
return_output_label (bool, optional): Output and label will be returned if True.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data, batch_size = engine.load_batch(data_iter)
assert (
batch_size == self._grad_accum_size
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else:
            # if no batch data process func is given,
# then we regard the batch data as a simple tuple of (data, label)
data, label = batch_data
loss = 0 if return_loss else None
outputs = []
labels = []
# reset accumulation microbatch offset
self._grad_accum_offset = 0
for _current_accum_step in range(self._grad_accum_size):
if _current_accum_step == self._grad_accum_size - 1:
engine.optimizer.skip_grad_reduce = False
else:
engine.optimizer.skip_grad_reduce = True
_data, _label = self._load_accum_batch(data, label)
_output, _loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
if return_loss:
loss += _loss
if return_output_label:
outputs.append(_output)
labels.append(_label)
if not return_output_label:
outputs, labels = None, None
return outputs, labels, loss

View File

@ -0,0 +1,12 @@
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
__all__ = [
"BaseScheduler",
"NonPipelineScheduler",
"InterleavedPipelineScheduler",
"PipelineScheduler",
"SchedulerHook",
"SchedulerMetricHook",
]

View File

@ -0,0 +1,187 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional
import torch
from internlm.core.engine import Engine
from internlm.utils.megatron_timers import megatron_timer as timer
class BaseScheduler(ABC):
"""A basic helper class to control the process of training or evaluation.
    It mainly consists of forward_backward_step, which runs the forward and backward pass,
    and optimizer_step, which updates the parameters.
    For convenience when enabling FP16, all FP16-related control code is aggregated in the
    schedule classes.
Args:
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
them into data and label.
"""
def __init__(self, data_process_func: Callable = None):
self.data_process_func = data_process_func
@abstractmethod
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
"""
pass
def _load_micro_batch(self, data, label, offset, micro_bsz):
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
micro_batch_label = label[offset : offset + micro_bsz]
return micro_batch_data, micro_batch_label
@abstractmethod
def forward_backward_step(
self,
engine: Engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function over a batch of dataset for training or evaluation.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
forward_only (bool): If True, the process won't include backward.
return_loss (bool, optional): If False, the loss won't be returned.
return_output_label (bool, optional): If False, the output and label won't be returned.
"""
pass
@staticmethod
def _call_engine(engine: Engine, inputs: Any):
"""Calls the engine with the given inputs.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
"""
if isinstance(inputs, torch.Tensor):
return engine(inputs)
elif isinstance(inputs, (list, tuple)):
return engine(*inputs)
elif isinstance(inputs, dict):
return engine(**inputs)
else:
raise TypeError(
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
)
@staticmethod
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
"""Calls the engine's criterion with the given outputs and labels.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
"""
assert isinstance(
outputs, (torch.Tensor, list, tuple, dict)
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):
labels = (labels,)
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
return engine.criterion(*outputs, *labels)
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
return engine.criterion(*outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, dict):
return engine.criterion(**outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
else:
raise TypeError(
f"Expected model outputs and labels to be of type torch.Tensor ' \
'(which is auto-converted to tuple), list, tuple, or dict, ' \
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
)
class SchedulerHook(ABC):
"""
Scheduler Hook.
"""
@abstractmethod
def before_forward(self, scheduler, inputs) -> None:
"""Actions before forward"""
@abstractmethod
def after_forward(self, scheduler, outputs) -> None:
"""Actions after forward"""
@abstractmethod
def before_criterion(self, scheduler, outputs, label) -> None:
"""Actions before criterion"""
@abstractmethod
def after_criterion(self, scheduler, loss) -> None:
"""Actions after criterion"""
@abstractmethod
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
"""Actions before backward"""
@abstractmethod
def after_backward(self, scheduler, inputs_grad) -> None:
"""Actions after backward"""
@abstractmethod
def post_helper_func(self, scheduler, outputs, label) -> None:
"""A post helper function"""
class SchedulerMetricHook(SchedulerHook):
"""
Scheduler Metric Hook.
"""
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
self._post_func = metric
self._skip = skip
def before_forward(self, scheduler, inputs) -> None:
if not self._skip:
timer("fwd").start()
def after_forward(self, scheduler, outputs) -> None:
if not self._skip:
timer("fwd").stop()
def before_criterion(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("cal_loss").start()
def after_criterion(self, scheduler, loss) -> None:
if not self._skip:
timer("cal_loss").stop()
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if not self._skip:
timer("bwd").start()
def after_backward(self, scheduler, inputs_grad) -> None:
if not self._skip:
timer("bwd").stop()
def post_helper_func(self, scheduler, outputs, label) -> None:
if self._post_func is not None:
self._post_func(outputs, label)
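
# --- Hedged example (editor addition, not part of the diff) ---
# A minimal custom hook that counts forward/backward passes. Any subclass only has
# to implement every abstract method of SchedulerHook; unused callbacks can simply pass.
class CountingSchedulerHook(SchedulerHook):
    """Counts forward and backward invocations, e.g. for debugging schedules."""

    def __init__(self) -> None:
        self.num_forward = 0
        self.num_backward = 0

    def before_forward(self, scheduler, inputs) -> None:
        self.num_forward += 1

    def after_forward(self, scheduler, outputs) -> None:
        pass

    def before_criterion(self, scheduler, outputs, label) -> None:
        pass

    def after_criterion(self, scheduler, loss) -> None:
        pass

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        self.num_backward += 1

    def after_backward(self, scheduler, inputs_grad) -> None:
        pass

    def post_helper_func(self, scheduler, outputs, label) -> None:
        pass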

View File

@ -0,0 +1,192 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from typing import Any, Callable, Iterable, List, Optional
import torch
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
from .base_scheduler import BaseScheduler, SchedulerHook
class NonPipelineScheduler(BaseScheduler):
"""A helper schedule class for no pipeline parallelism running environment.
    During one process, it loads a batch of data and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
Args:
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
and returns a tuple in the form of (data, label), and it will be executed in load_batch.
        gradient_accumulation_size (int, optional): the number of gradient accumulation steps; set to 1
            to disable gradient accumulation.
Example:
# this shows an example of customized data_process_func
def data_process_func(dataloader_output):
item1, item2, item3 = dataloader_output
data = (item1, item2)
label = item3
return data, label
"""
def __init__(
self,
data_process_func: Callable = None,
gradient_accumulation_size: int = 1,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
):
self._grad_accum_size = gradient_accumulation_size
self._grad_accum_offset = 0
self._hooks = scheduler_hooks
super().__init__(data_process_func)
def pre_processing(self, engine: Engine):
"""Performs actions before running the schedule.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
"""
pass
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
for hook in self._hooks:
getattr(hook, func_name)(self, *args, **kwargs)
def _load_accum_batch(self, data: Any, label: Any):
"""Loads a batch of data and label for gradient accumulation.
Args:
data (Any): The data to be loaded.
label (Any): The label to be loaded.
"""
_data, _label = self._load_micro_batch(
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
)
self._grad_accum_offset += self._grad_accum_batch_size
if self.data_process_func:
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
_label = self.data_process_func(_label, _data["cu_seqlens"])
_data.pop("cu_seqlens")
_data.pop("indexes")
return _data, _label
def _train_one_batch(
self,
data: Any,
label: Any,
engine: Engine,
forward_only: bool = False,
return_loss: bool = True,
scale_loss: int = 1,
):
"""Trains one batch of data.
Args:
data (Any): The data to be trained.
label (Any): The label for the data.
engine (internlm.core.Engine): InternLM engine for training and inference.
forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
be executed.
return_loss (bool, optional): Loss will be returned if True.
scale_loss (int, optional): The scale factor for the loss.
"""
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
output = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
self._call_hooks("post_helper_func", output, label)
if return_loss:
self._call_hooks("before_criterion", output, label)
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
loss /= scale_loss
# backward
if not forward_only:
self._call_hooks("before_backward", None, None)
engine.backward(loss)
self._call_hooks("after_backward", None)
if not return_loss:
loss = None
return output, loss
def forward_backward_step(
self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
Args:
engine (internlm.core.Engine): InternLM engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
If True, the model is run for the forward pass, else back propagation will be executed.
return_loss (bool, optional): Loss will be returned if True.
return_output_label (bool, optional): Output and label will be returned if True.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data, batch_size = engine.load_batch(data_iter)
assert (
batch_size % self._grad_accum_size == 0
), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
self._grad_accum_batch_size = batch_size // self._grad_accum_size
data, label = batch_data
loss = 0 if return_loss else None
outputs = []
labels = []
# reset accumulation microbatch offset
self._grad_accum_offset = 0
for _current_accum_step in range(self._grad_accum_size):
if _current_accum_step == self._grad_accum_size - 1:
engine.optimizer.skip_grad_reduce = False
else:
engine.optimizer.skip_grad_reduce = True
_data, _label = self._load_accum_batch(data, label)
_output, _loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
if return_loss:
loss += _loss
if return_output_label:
outputs.append(_output)
labels.append(_label)
if not return_output_label:
outputs, labels = None, None
return outputs, labels, loss
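
# --- Hedged construction sketch (editor addition) ---
# Wires the scheduler with a metric hook and 4-step gradient accumulation; the
# values are illustrative assumptions, and the engine/data iterator passed to
# forward_backward_step come from initialize_trainer in a real run.
from .base_scheduler import SchedulerMetricHook

_example_scheduler = NonPipelineScheduler(
    gradient_accumulation_size=4,
    scheduler_hooks=[SchedulerMetricHook(skip=True)],
)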

File diff suppressed because it is too large

View File

@ -7,7 +7,12 @@ import json
from typing import Iterable, Optional
from internlm.core.engine import Engine
from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
from internlm.core.scheduler import (
BaseScheduler,
InterleavedPipelineScheduler,
NonPipelineScheduler,
PipelineScheduler,
)
class TrainState:
@ -33,6 +38,11 @@ class TrainState:
# Total step count
self.total_steps: int = config.data.total_steps
# resume tensorboard folder, need load from checkpoint or set manually.
self.resume_tb_folder = config.resume_tb_folder
self.tensorboard_folder = config.tensorboard_folder
def init_batch_sampler(self, train_dl):
# Copy of the batch sampler from the DataLoader
self.batch_sampler = train_dl.batch_sampler.copy()
@ -71,6 +81,9 @@ class TrainState:
self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)
# resume tensorboard from older tensorboard_folder
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
def state_dict(self):
return {
"batch_count": self.batch_count,
@ -78,6 +91,7 @@ class TrainState:
"num_consumed_tokens": self.num_consumed_tokens,
"inf_nan_skip_batches": self.inf_nan_skip_batches,
"step_count": self.step_count,
"tensorboard_folder": self.tensorboard_folder,
}
@ -112,8 +126,7 @@ class Trainer:
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
self._schedule = schedule
if self.uses_pipeline:
self._schedule.pre_processing(self)
self._schedule.pre_processing(self._engine)
@property
def engine(self):
@ -126,7 +139,7 @@ class Trainer:
@property
def uses_pipeline(self):
"""Returns whether the pipeline parallel is used or not."""
return False
return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
def train(self):
self._engine.train()

View File

@ -219,11 +219,6 @@ class StaticBatchSampler:
assert (
batch_size - self.start_bsz
) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
assert (
self.start_bsz // micro_bsz >= 4
), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
`{micro_bsz}`, so that the pipeline can run correctly"
assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
assert (
self.start_bsz % micro_bsz == 0

56
internlm/data/dataset.py Normal file
View File

@ -0,0 +1,56 @@
import os
from typing import Dict
from torch.utils.data import ConcatDataset
from internlm.data.single_dataset import JsonlDataset
def get_dataset_dict(folder, split="valid") -> Dict:
"""
Return a dictionary of Datasets from a folder containing data files for validation.
Args:
folder (str): The path to the folder containing data files.
split (str): The split of the data files to be used, default is "valid".
Returns:
A dictionary containing Datasets for each folder in the given path
that contains data files with the specified split.
Raises:
AssertionError: If the given folder does not exist.
Example:
If the given folder is as follows,
- data
- zhihu
- xxx.bin
- valid.bin
- baike
- xxx.bin
- valid.bin
The returned dictionary will be,
{
'zhihu': Dataset,
'baike': Dataset
}
"""
    assert os.path.exists(folder), f"folder `{folder}` does not exist"
data_dict = {}
for root, dirs, files in os.walk(folder, followlinks=True):
        dirs.sort()  # Sort to guarantee a stable order; newly added data prefixed with "z" is intentionally ordered last
datasets = []
for fn in sorted(files): # Need sorted to ensure that the order is consistent
if fn.endswith(".bin") and split in fn:
fp = os.path.join(root, fn)
ds = JsonlDataset(fp)
datasets.append(ds)
if datasets:
ds = ConcatDataset(datasets=datasets)
data_dict[os.path.basename(root)] = ds
return data_dict
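
# --- Hedged usage sketch (editor addition) ---
# The folder path below is an assumption; every sub-folder that contains a
# "*valid*.bin" file contributes one entry to the returned dict.
if __name__ == "__main__":
    valid_sets = get_dataset_dict(folder="/path/to/data", split="valid")
    for name, dataset in valid_sets.items():
        print(f"{name}: {len(dataset)} samples")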

View File

@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
return out
def cal_pos_unpack(self, index):
if index == 0:
pre_pos = 0
else:
pre_pos = index * gpc.config.data["micro_bsz"]
pos = (index + 1) * gpc.config.data["micro_bsz"]
return pre_pos, pos
def build_unpack(self, index):
pre_pos, pos = self.cal_pos_unpack(index)
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
while pre_pos < pos and pre_pos < len(self.dataset):
sample_idx = self.sample_indices[pre_pos]
sample = self.dataset[sample_idx]
length = min(len(sample["tokens"]), self.max_length_per_sample)
chunk = sample["tokens"][0:length]
pack.extend(chunk)
_labels = deepcopy(chunk)
_labels = list(_labels[1:]) + [-100]
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
cu_seqlens.append(cu_seqlens[-1] + len(chunk))
indexes.extend(list(range(length)))
pre_pos = pre_pos + 1
if cu_seqlens[-1] != self.packed_length:
pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
cu_seqlens.append(self.packed_length)
assert len(pack) == self.packed_length
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
return out
def __getitem__(self, item: int) -> Dict:
"""Given the index, it returns a dict as
{
@ -154,8 +196,11 @@ class PackedDataset(torch.utils.data.Dataset):
}
"""
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
if gpc.config.model.use_flash_attn:
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
return self.build_unpack(item)
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):

View File

@ -1,7 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
import torch
from internlm.core.context import global_context as gpc
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2}
def get_dataset_type_id(path):
@ -13,3 +17,30 @@ def get_dataset_type_id(path):
match_idxes.append(idx)
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
return match_idxes[0]
def unpack_data(input_ids, cu_seqlens):
"""
input_ids: (n, packed_length)
Return:
output: (batch_size, max_length)
"""
bsz = input_ids.shape[0]
num_sequence = gpc.config.data["micro_bsz"]
outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
for i in range(bsz):
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
cu_seqlens_slice = cu_seqlens[i]
for j in range(num_sequence):
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
            output[j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]  # index the i-th packed row
outputs[i] = output
if bsz == 1:
outputs = outputs.squeeze(0)
return outputs
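
# --- Hedged, standalone sketch (editor addition) ---
# Re-implements the unpacking loop above with micro_bsz/seq_len as local
# assumptions instead of gpc.config, so it can run without a launched job.
def _unpack_demo() -> torch.Tensor:
    micro_bsz, seq_len = 2, 4
    # one packed row holding two sequences of lengths 3 and 2
    input_ids = torch.tensor([[11, 12, 13, 21, 22, 0, 0, 0]])
    cu_seqlens = torch.tensor([[0, 3, 5]])
    out = torch.zeros(1, micro_bsz, seq_len, dtype=input_ids.dtype)
    for j in range(micro_bsz):
        length = cu_seqlens[0, j + 1] - cu_seqlens[0, j]
        out[0, j, :length] = input_ids[0, cu_seqlens[0, j] : cu_seqlens[0, j + 1]]
    return out.squeeze(0)  # tensor([[11, 12, 13, 0], [21, 22, 0, 0]])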

View File

@ -1,9 +1,15 @@
from .initialize_trainer import initialize_trainer
from .launch import get_default_parser, launch_from_slurm, launch_from_torch
from .launch import (
get_default_parser,
initialize_distributed_env,
launch_from_slurm,
launch_from_torch,
)
__all__ = [
"get_default_parser",
"initialize_trainer",
"launch_from_slurm",
"launch_from_torch",
"initialize_distributed_env",
]

View File

@ -3,7 +3,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
from typing import Callable, Iterable, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple
from torch import nn
from torch.nn.modules.loss import _Loss
@ -11,11 +11,19 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
from internlm.core.scheduler import (
InterleavedPipelineScheduler,
NonPipelineScheduler,
PipelineScheduler,
SchedulerHook,
)
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
from internlm.core.trainer import Trainer
from internlm.data.utils import unpack_data
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_current_device
@ -29,6 +37,7 @@ def initialize_trainer(
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
beta2_scheduler: Optional[Beta2Scheduler] = None,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
@ -59,6 +68,8 @@ def initialize_trainer(
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
# gradient handler, only support PipelineSharedModuleGradientHandler now
if gpc.is_using_pp():
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
gradient_handlers = []
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
@ -67,8 +78,50 @@ def initialize_trainer(
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
gradient_handlers.append(handler)
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
# initialize scheduler for trainer
scheduler = None
if gpc.config.model.use_flash_attn:
data_fn = None
else:
data_fn = unpack_data
if gpc.is_using_pp():
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
tensor_shape = get_tensor_shape()
use_interleaved = (
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
)
scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
if use_interleaved:
if isinstance(model, nn.Sequential):
model = nn.ModuleList([model])
communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
scheduler = InterleavedPipelineScheduler(
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
num_chunks=gpc.config.model.num_chunks,
dtype=gpc.config.model["dtype"],
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather,
scheduler_hooks=scheduler_hooks,
communication_overlap=communication_overlap,
)
else:
scheduler = PipelineScheduler(
data_process_func=data_fn,
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
dtype=gpc.config.model["dtype"],
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather,
scheduler_hooks=scheduler_hooks,
)
else:
scheduler = NonPipelineScheduler(
data_process_func=data_fn,
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
scheduler_hooks=scheduler_hooks,
)
# initialize engine for trainer
engine = Engine(
model=model,
optimizer=optimizer,

View File

@ -10,7 +10,9 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
logger = get_logger(__file__)
@ -38,7 +40,7 @@ def get_default_parser():
parser.add_argument("--local_rank", type=int, help="local rank on the node")
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
parser.add_argument("--seed", type=int, default=1024)
parser.add_argument("--profiling", default=False, action="store_true", help="enable/diable profiling.")
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
return parser
@ -89,6 +91,12 @@ def args_sanity_check():
if "valid_folder" not in data:
data._add_item("valid_folder", None)
if "valid_micro_num" not in data:
data._add_item("valid_micro_num", data.micro_num)
if "valid_every" not in data:
data._add_item("valid_every", 0)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"seq_len: {data.seq_len}")
@ -97,36 +105,104 @@ def args_sanity_check():
logger.info(f"packed_length: {data.packed_length}")
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
logger.info(f"min_length: {data.min_length}")
logger.info(f"valid_micro_num: {data.valid_micro_num}")
logger.info(f"valid_every: {data.valid_every}")
# processing the checkpoint config
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
ckpt = gpc.config.ckpt
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", False)
if "load_optimizer" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_optimizer", True)
# Saving checkpoint args.
if ckpt.enable_save_ckpt:
assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
assert ckpt.checkpoint_every > 0
assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
if "save_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("save_ckpt_folder", None)
if "async_upload" not in ckpt:
ckpt._add_item("async_upload", False) # async defalut is False.
else:
if ckpt.async_upload:
assert "save_ckpt_folder" in ckpt
if "boto3:" not in ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning(
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
)
ckpt.async_upload = False
else:
if "async_upload_tmp_folder" not in ckpt:
ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "load_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_ckpt_folder", None)
if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None)
if "load_model_only_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_model_only_folder", None)
if "snapshot_ckpt_folder" not in ckpt:
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
assert not (
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
if "oss_snapshot_freq" not in ckpt:
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
else:
ckpt._add_item("checkpoint_every", float("inf"))
ckpt._add_item("oss_snapshot_freq", float("inf"))
ckpt._add_item("save_ckpt_folder", None)
ckpt._add_item("async_upload", False)
ckpt._add_item("async_upload_tmp_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
gpc.config.ckpt._add_item(
"enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
)
# Loading checkpoint args.
if "load_model_only_folder" not in ckpt:
ckpt._add_item("load_model_only_folder", None)
if "load_ckpt_folder" not in ckpt:
ckpt._add_item("load_ckpt_folder", None)
if "load_optimizer" not in ckpt:
ckpt._add_item("load_optimizer", True)
if "stop_file_path" not in ckpt:
ckpt._add_item("stop_file_path", None)
if "load_given_ckpt" not in ckpt:
        # If 'load_given_ckpt' is not given, we set it to False, so internlm has the opportunity
        # to auto-load the latest checkpoint.
ckpt._add_item("load_given_ckpt", False)
if ckpt.load_given_ckpt:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
logger.warning(
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
)
ckpt.load_model_only_folder = None
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
# initialization storage manager
init_storage_manager(ckpt)
# tensorboard writer config
if "enable_tb" not in gpc.config:
gpc.config._add_item("enable_tb", True)
if "tensorboard_folder" not in gpc.config:
gpc.config._add_item(
"tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
)
if "resume_tb_folder" not in gpc.config:
gpc.config._add_item(
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
)
if gpc.is_rank_for_log():
logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
# cudnn
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
@ -144,12 +220,24 @@ def args_sanity_check():
logger.warning("dtype is not set, use torch.float16 by defalut!")
model._add_item("dtype", torch.float16)
else:
if model.dtype == "torch.bfloat16":
model.dtype = torch.bfloat16
elif model.dtype in ("torch.float16", "torch.half"):
model.dtype = torch.float16
if gpc.config.model.dtype == "torch.bfloat16":
gpc.config.model.dtype = torch.bfloat16
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
gpc.config.model.dtype = torch.float16
elif gpc.config.model.dtype == "torch.float32":
gpc.config.model.dtype = torch.float32
elif gpc.config.model.dtype == "torch.tf32":
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
gpc.config.model.dtype = torch.float32
else:
assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
assert gpc.config.model.dtype in [
"torch.float16",
"torch.half",
"torch.bfloat16",
"torch.float32",
"torch.tf32",
]
if "checkpoint" in model:
if model.checkpoint is True:
@ -177,6 +265,35 @@ def args_sanity_check():
logger.info("+" * 15 + " beta2_scheduler Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}")
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
else:
assert not (
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt:
# Compatible with the old interfaces.
optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
if "overlap_sync_grad" not in optim_ckpt:
optim_ckpt._add_item("overlap_sync_grad", False)
if "overlap_sync_param" not in optim_ckpt:
optim_ckpt._add_item("overlap_sync_param", False)
if gpc.is_rank_for_log():
logger.info(
f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
)
def launch(
config: Union[str, Path, Config, Dict],
@ -223,8 +340,6 @@ def launch(
# init process groups for different parallel modes from config
gpc.init_parallel_groups()
args_sanity_check()
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
@ -277,7 +392,11 @@ def launch_from_slurm(
)
def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
def launch_from_torch(
config: Union[str, Path, Config, Dict],
backend: str = "nccl",
seed: int = 1024,
):
"""A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@ -305,3 +424,38 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nc
backend=backend,
seed=seed,
)
def initialize_distributed_env(
config: str,
launcher: str = "slurm",
master_port: int = 8888,
seed: int = 1024,
args_check=True,
):
"""
Initialize distributed environment for distributed training.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
        master_port (int): The master port for distributed training. 8888 by default.
        seed (int, optional): Specified random seed for every process. 1024 by default.
        args_check (bool, optional): Whether to run args_sanity_check after launching. True by default.
"""
torch.cuda.empty_cache()
if launcher == "torch":
launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
if args_check:
args_sanity_check()
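
# --- Hedged usage sketch (editor addition) ---
# A minimal training entry point; the config path is an assumption. With
# launcher="torch" the script is expected to be started via torchrun so the
# rank/world-size environment variables are already set.
if __name__ == "__main__":
    initialize_distributed_env(config="./configs/7B_sft.py", launcher="torch", seed=1024)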

View File

@ -3,6 +3,7 @@
from .embedding import Embedding1D, RotaryEmbedding
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
from .metrics import AccPerplex
from .modeling_internlm import build_model_with_cfg
from .multi_head_attention import MHA
from .utils import gather_forward_split_backward
@ -13,6 +14,7 @@ __all__ = [
"RotaryEmbedding",
"RewardModelLinear",
"ScaleColumnParallelLinear",
"AccPerplex",
"MHA",
"gather_forward_split_backward",
"build_model_with_cfg",

View File

@ -7,13 +7,14 @@ import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from .utils import gather_forward_split_backward
from .utils import gather_forward_split_backward, split_forward_gather_backward
class Embedding1D(nn.Module):
@ -56,6 +57,9 @@ class Embedding1D(nn.Module):
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
if gpc.config.parallel.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
return output
@ -108,6 +112,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
class RotaryEmbedding(torch.nn.Module):
@ -176,7 +181,15 @@ class RotaryEmbedding(torch.nn.Module):
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, qkv: torch.Tensor, **kwargs):
if kwargs.get("indexes", None) is not None:
return self._forward(qkv, kwargs.pop("indexes"))
if kwargs.get("inference_params", None) is not None:
return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
else:
return self._eval_forward(qkv)
def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
self._update_cos_sin_cache(qkv, indexes)
if self.scale is None:
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
@ -189,7 +202,7 @@ class RotaryEmbedding(torch.nn.Module):
self._sin_k_cached[indexes],
)
def eval_forward(self, qkv, seqlen_offset=0):
def _eval_forward(self, qkv, seqlen_offset=0):
"""
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.

View File

@ -5,15 +5,13 @@ from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import (
ColumnParallelLinear,
RowParallelLinear,
fused_dense_func,
)
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch
class ScaleColumnParallelLinear(nn.Linear):
@ -40,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
@ -50,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.weight_scale = weight_scale
def forward(self, input): # pylint: disable=W0622
@ -61,8 +57,12 @@ class ScaleColumnParallelLinear(nn.Linear):
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fused_dense_func(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
return fused_dense_func_torch(
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)
@ -89,12 +89,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
) -> None:
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
if bias:
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
@ -107,11 +106,37 @@ class RewardModelLinear(ScaleColumnParallelLinear):
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fused_dense_func(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
return fused_dense_func_torch(
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)
class ColumnParallelLinearTorch(ColumnParallelLinear):
def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func_torch(
x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
)
class RowParallelLinearTorch(RowParallelLinear):
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = fused_dense_func_torch(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FeedForward(nn.Module):
"""
FeedForward.
@ -143,24 +168,30 @@ class FeedForward(nn.Module):
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
self.w1 = ColumnParallelLinearTorch(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = ColumnParallelLinear(
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
self.w2 = ColumnParallelLinearTorch(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = RowParallelLinear(
self.w3 = RowParallelLinearTorch(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

263
internlm/model/metrics.py Normal file
View File

@ -0,0 +1,263 @@
from typing import List
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch_scatter import scatter
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage
class AccPerplex:
"""
AccPerplex module for calculating model's accuracy and perplexity metrics.
Args:
device: The GPU device.
tp_pg: The tensor parallel process group.
dp_pg: The data parallel process group.
tokenizer: For calculating BPB.
dataset_types (List[str]): Various data types that will be used in the current training process,
such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
"""
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
self.tokenizer = tokenizer
self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
self.batch_shift = 0
self.type_ids = None
if dataset_types is not None:
self.dataset_types = dataset_types
self.total_type_count = len(dataset_types)
self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)
def set_current_type_ids(self, type_ids: torch.Tensor):
self.batch_shift = 0
self.type_ids = type_ids.cuda()
def __call__(self, logits, labels):
return self.update(logits, labels, type_ids=self.type_ids)
def update(self, logits, labels, type_ids=None):
if gpc.config.model.use_flash_attn:
micro_bsz = labels.size(0)
else:
micro_bsz = 1
if type_ids is not None:
type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
self.batch_shift += 1
self.loss_with_type_id.update(logits, labels, type_ids)
with torch.no_grad():
if isinstance(logits, (list, tuple)):
logits = logits[0]
logits = logits.detach().clone()
labels = labels.detach().clone()
if self.tokenizer: # need to calculate bits per bytes
sequences = self.tokenizer.decode_ids(labels.tolist())
self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))
shift_logits = logits.view(-1, logits.size(-1))
shift_labels = labels.view(-1)
# There is a shift according to the current rank, because the logits are split
pred_shift = self.tp_local_rank * logits.shape[-1]
logits_max = torch.max(shift_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
# Determine whether the maximum value of the current local tensor is the global maximum value
logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]
corrects = torch.logical_and(
(shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
).long()
mask = shift_labels.ne(-100).long()
if hasattr(self, "total_type_count"):
ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
if len(ds_acc) < self.total_type_count:
ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
token_num_type = torch.cat(
[token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
)
self.ds_tokens += token_num_type
sync_tensor = ds_acc
torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
self.ds_right += sync_tensor.view(-1)
acc = corrects.sum()
torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
self.right += acc # masked_fill is not needed here: the -100 padding label can never be predicted, so it is never counted as correct
self.total += mask.sum()
# Subtract the maximum value.
shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
partition_vocab_size = shift_logits.size()[-1]
vocab_start_index = partition_vocab_size * self.tp_local_rank
vocab_end_index = vocab_start_index + partition_vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
masked_target = shift_labels - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = shift_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(shift_labels) # bsz x max_len
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
pred_exp_logits = torch.exp(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
self.total_log_probs += total_log_probs
def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if hasattr(self, "total_type_count"):
torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if self.tokenizer:
torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
acc = round((self.right / self.total).item(), 4)
perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0
if hasattr(self, "total_type_count"):
ds_acc = {}
ds_tokens = {}
for i in range(self.total_type_count):
ds_acc[f"acc/{self.dataset_types[i]}"] = round(
(self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
)
ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
if reset:
self.right.fill_(0)
self.total.fill_(0)
self.total_log_probs.fill_(0)
self.total_bytes.fill_(0)
if hasattr(self, "total_type_count"):
self.ds_right.fill_(0)
self.ds_tokens.fill_(0)
if self.tokenizer is not None:
res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
else:
res = {"acc": acc, "perplexity": perplexity}
if hasattr(self, "total_type_count"):
res.update(ds_acc)
res.update(ds_tokens)
loss_res = self.loss_with_type_id.get_metric()
res.update(loss_res)
return res
class LossWithTypeId:
"""
Notice that the loss value computed here may not be the same as the loss reported in the main training info,
because the loss here is reduced over the data parallel group.
"""
def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
self.device = device
self.dp_pg = dp_pg
self.loss = torch.Tensor([0.0]).to(device=device)
self.token_num = torch.Tensor([0.0]).to(device=device)
if dataset_types is not None:
self.dataset_types = dataset_types
self.total_type_count = len(dataset_types)
self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
self.loss_fn = FlashCrossEntropyLoss(
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
)
def update(self, logits, labels, type_ids=None):
with torch.no_grad():
if isinstance(logits, (list, tuple)):
logits = logits[0]
logits = logits.contiguous().view(-1, logits.size(-1))
labels = labels.contiguous().view(-1)
loss_list = self.loss_fn(logits, labels)
cond = labels != -100
real_loss_list = loss_list[cond]
self.loss += real_loss_list.sum()
self.token_num += real_loss_list.numel()
if hasattr(self, "total_type_count"):
type_ids = type_ids.contiguous().view(-1).to(self.device)
real_type_ids = type_ids[cond]
loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")
if len(loss_list_type) < self.total_type_count:
loss_list_type = torch.cat(
[loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
)
token_num_type = torch.cat(
[token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
)
self.ds_loss += loss_list_type
self.ds_token_num += token_num_type
def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if hasattr(self, "total_type_count"):
torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
loss = round((self.loss / self.token_num).item(), 4)
res = {
"loss_from_metric": loss,
}
if hasattr(self, "total_type_count"):
ds_loss = {}
for i in range(self.total_type_count):
ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
res.update(ds_loss)
if reset:
self.loss.fill_(0.0)
self.token_num.fill_(0.0)
if hasattr(self, "total_type_count"):
self.ds_loss.fill_(0.0)
self.ds_token_num.fill_(0.0)
return res

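The per-dataset-type accounting in `AccPerplex.update` uses `torch_scatter.scatter(..., reduce="sum")` and then zero-pads when some types are absent from the batch. A pure-torch stand-in for that one step (illustrative only; the real path also all-reduces across the tensor/data parallel groups):

```python
import torch

def per_type_counts(values: torch.Tensor, type_ids: torch.Tensor, total_type_count: int) -> torch.Tensor:
    # stand-in for scatter(values, type_ids, dim=0, reduce="sum")
    counts = torch.zeros(int(type_ids.max()) + 1, dtype=values.dtype)
    counts.index_add_(0, type_ids, values)
    # same padding as in update() when a dataset type does not appear in this batch
    if len(counts) < total_type_count:
        counts = torch.cat([counts, counts.new_zeros(total_type_count - len(counts))])
    return counts

corrects = torch.tensor([1, 0, 1, 1, 0])  # 1 = token predicted correctly
type_ids = torch.tensor([0, 0, 1, 1, 1])  # e.g. 0='en', 1='cn'; a third type 'code' is unseen
print(per_type_counts(corrects, type_ids, total_type_count=3))  # tensor([1, 2, 0])
```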

@ -5,7 +5,6 @@ import math
from typing import Optional
import torch
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn
@ -20,7 +19,7 @@ from internlm.model.linear import (
ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
@ -30,6 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
MODEL_TYPE = "INTERNLM"
logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
class PackedFlashBaseLayer1D(nn.Module):
@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
"""
def __init__(
@ -68,12 +69,14 @@ class PackedFlashBaseLayer1D(nn.Module):
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
):
super().__init__()
self.checkpoint = checkpoint
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
self.layer_idx = layer_idx
self.use_flash_attn = use_flash_attn
head_dim = hidden_size // num_attention_heads
self.mixer = MHA(
@ -86,8 +89,7 @@ class PackedFlashBaseLayer1D(nn.Module):
layer_idx=layer_idx,
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=True,
sequence_parallel=False,
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
)
@ -119,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
@ -243,6 +245,7 @@ class PackedFlashInternLm1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@ -271,6 +274,7 @@ class PackedFlashInternLm1D(nn.Module):
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
):
super().__init__()
@ -290,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
@ -317,6 +321,7 @@ class PackedFlashInternLm1D(nn.Module):
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
)
for lid in range(num_layers)
]
@ -331,7 +336,6 @@ class PackedFlashInternLm1D(nn.Module):
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
sequence_parallel=False,
device=device,
dtype=dtype,
weight_scale=embed_grad_scale,
@ -397,9 +401,10 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
# all_parts = partition_uniform_with_embed2(num_layers, pipeline_size, num_chunks)
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
parts = all_parts[pipeline_rank]
if gpc.is_rank_for_log():
logger.info(f"The layer sharding is {all_parts}.")
models = []
@ -445,6 +450,8 @@ def build_model_with_cfg(
dropout_selective_checkpoint=True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
sequence_parallel: bool = False, # pylint: disable=W0613
):
"""
Build model with config
@ -474,6 +481,7 @@ def build_model_with_cfg(
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@ -496,6 +504,7 @@ def build_model_with_cfg(
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)

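The changes above thread two switches through the model: `sequence_parallel` is now read from `gpc.config.parallel.sequence_parallel` instead of being hard-coded to False, and `use_flash_attn` is forwarded from the model config. A hypothetical config excerpt showing where those fields live; only the field names come from this diff, the surrounding layout is an assumption:

```python
# hypothetical excerpt of a training config read via gpc.config
model = dict(
    dtype="torch.bfloat16",   # with torch.float32, attention inputs are autocast to bfloat16 (see MHA below)
    use_flash_attn=True,      # read as gpc.config.model.use_flash_attn
    norm_type="rmsnorm",      # resolves to apex RMSNorm, or RMSNormTorch if apex is missing
)
parallel = dict(
    sequence_parallel=False,  # read as gpc.config.parallel.sequence_parallel
)
```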

@ -12,12 +12,12 @@ from flash_attn.modules.mha import (
SelfAttention,
_update_kv_cache,
)
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
class MHA(nn.Module):
@ -43,6 +43,7 @@ class MHA(nn.Module):
of x will be done before doing the matmul.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@ -57,8 +58,7 @@ class MHA(nn.Module):
layer_idx: int = None,
rotary_emb_dim: int = 0,
rotary_emb_scale_base: int = 0,
use_flash_attn: bool = False,
sequence_parallel: bool = True,
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
@ -77,12 +77,12 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# note: bias should be set to True here
self.Wqkv = ColumnParallelLinear(
self.Wqkv = ColumnParallelLinearTorch(
embed_dim,
3 * embed_dim,
process_group,
bias=True,
sequence_parallel=sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
@ -94,8 +94,12 @@ class MHA(nn.Module):
)
# output projection always have the bias (for now)
self.out_proj = RowParallelLinear(
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
self.out_proj = RowParallelLinearTorch(
embed_dim,
embed_dim,
process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign the tp attribute so that internlm knows it is a tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
@ -107,9 +111,9 @@ class MHA(nn.Module):
if kwargs.get("indexes", None) is not None:
return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
else:
return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
def _forward(self, x, seqlen=None, inference_params=None):
def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@ -124,13 +128,17 @@ class MHA(nn.Module):
qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
if self.rotary_emb_dim > 0:
if inference_params is None:
qkv = self.rotary_emb.eval_forward(qkv)
else:
qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
kwargs["inference_params"] = inference_params
qkv = self.rotary_emb(qkv, **kwargs)
if inference_params is None:
context = self.inner_attn(qkv)
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
context = self.inner_attn(qkv).to(x.dtype)
else:
context = self.inner_attn(qkv)
else:
q = qkv[:, :, 0]
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
@ -158,10 +166,18 @@ class MHA(nn.Module):
"""
qkv = self.Wqkv(x) # total x hsz'
qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
qkv = self.rotary_emb(qkv, **kwargs)
kwargs.pop("indexes")
if inference_params is None:
context = self.inner_attn(qkv, **kwargs)
if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if qkv.dtype not in [torch.float16, torch.bfloat16]:
qkv = qkv.to(torch.bfloat16)
context = self.inner_attn(qkv, **kwargs).to(x.dtype)
else:
context = self.inner_attn(qkv, **kwargs)
else:
raise RuntimeError("Not support this right now")

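The new float32 branch in `MHA._forward`/`_packed_forward` casts `qkv` down to bfloat16 under autocast (flash-attn kernels only accept fp16/bf16) and casts the attention context back to the input dtype. A self-contained sketch of that pattern with a plain softmax attention standing in for `self.inner_attn` (a stand-in, not the flash-attn kernel); requires a CUDA device:

```python
import torch

def attn_fp32_safe(qkv: torch.Tensor) -> torch.Tensor:
    # qkv: (batch, seqlen, 3, heads, head_dim), possibly float32
    orig_dtype = qkv.dtype
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.bfloat16)
        q, k, v = qkv.unbind(dim=2)
        scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
        context = torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)
    return context.to(orig_dtype)  # cast back, mirroring `.to(x.dtype)` above

if torch.cuda.is_available():
    out = attn_fp32_safe(torch.randn(2, 16, 3, 4, 8, device="cuda"))
    print(out.dtype)  # torch.float32
```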
internlm/model/norm.py (new file)

@ -0,0 +1,46 @@
# adapted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
def manual_rms_norm(my_input, normalized_shape, weight, eps):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
my_input = my_input * torch.rsqrt(variance + eps)
if weight is None:
return my_input
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
my_input = my_input.to(weight.dtype)
return weight * my_input
class RMSNormTorch(torch.nn.Module):
"""A custom PyTorch module for RMS normalization."""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.empty(*normalized_shape))
self.reset_parameters()
def forward(self, _input: torch.Tensor):
return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)
def reset_parameters(self):
init.ones_(self.weight)
def extra_repr(self):
return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)

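A small sanity-check sketch for the pure-torch RMSNorm above, assuming the repository root is on `PYTHONPATH` so `internlm.model.norm` is importable:

```python
import torch
from internlm.model.norm import RMSNormTorch

norm = RMSNormTorch(8, eps=1e-5)  # weight is initialized to ones
x = torch.randn(2, 4, 8)
# reference: x * rsqrt(mean(x^2) + eps), matching manual_rms_norm for float32 input
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
print(torch.allclose(norm(x), norm.weight * ref, atol=1e-6))  # True
```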

@ -1,9 +1,24 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDenseFunc
from flash_attn.utils.distributed import (
all_gather_raw,
all_reduce_raw,
reduce_scatter_raw,
)
from torch import Tensor
from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def _split(input_, parallel_mode, dim=-1):
@ -71,3 +86,124 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
assert my_input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), my_input)
grad_bias = grad_output.sum(dim=0) if has_d_bias else None
return grad_weight, grad_bias
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
"""A custom PyTorch module extending FusedDenseFunc."""
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if process_group is not None and sequence_parallel:
handle_x.wait()
# unlike flash_attn, the weight gradient here is computed with plain torch instead of the fused CUDA kernel.
grad_weight, grad_bias = linear_bias_wgrad_torch(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def fused_dense_func_torch(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to this rank.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""
@staticmethod
def symbolic(input_):
return _split(input_, parallel_mode=None)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def try_import_RMSNorm():
"""
Try to import MixedFusedRMSNorm from apex; if that fails, fall back to our own RMSNorm implementation.
"""
try:
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
return RMSNorm
except ModuleNotFoundError:
logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
from internlm.model.norm import RMSNormTorch as RMSNorm
return RMSNorm

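A hedged usage sketch of the fallback above; note that importing `internlm.model.utils` itself requires `flash_attn`, while apex is optional:

```python
from internlm.model.utils import try_import_RMSNorm

RMSNorm = try_import_RMSNorm()  # apex MixedFusedRMSNorm if available, otherwise RMSNormTorch
norm = RMSNorm(4096, eps=1e-5)  # RMSNormTorch takes (normalized_shape, eps); apex's class is call-compatible here
print(type(norm).__name__)
```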

@ -0,0 +1,4 @@
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]

internlm/monitor/alert.py (new file)

@ -0,0 +1,53 @@
import json
import time
import requests
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
"""
Use Feishu robot to send messages with the given webhook.
Args:
webhook (str): The webhook to be used to send message.
title (str): The message title.
message (str): The message body.
Returns:
The response from the request, or None if an exception was caught.
Raises:
Exception: An exception raised by the HTTP POST request.
"""
headers = {"Content-Type": "application/json;charset=utf-8"}
msg_body = {
"timestamp": int(time.time()),
"msg_type": "post",
"content": {
"post": {
"zh_cn": {
"title": title,
"content": [
[
{
"tag": "text",
"text": message,
},
],
],
},
},
},
}
try:
res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30)
res = res.json()
print(f"Feishu webhook response: {res}")
except Exception as err: # pylint: disable=W0703
print(f"HTTP Post error: {err}")
res = None
return res

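A hedged usage sketch of the alert helper above; the webhook URL is a placeholder, not a real endpoint:

```python
from internlm.monitor.alert import send_feishu_msg_with_webhook

resp = send_feishu_msg_with_webhook(
    webhook="https://open.feishu.cn/open-apis/bot/v2/hook/<token>",  # placeholder webhook
    title="internlm-train-demo",
    message="Loss spike may have happened at step 1200, please check it.",
)
print(resp)  # parsed JSON response, or None if the HTTP request raised
```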
internlm/monitor/monitor.py (new file)

@ -0,0 +1,226 @@
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread
from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta
from .utils import get_job_key, set_env_var
def send_alert_message(address: str = None, title: str = None, message: str = None):
"""
Send alert messages to the given Feishu webhook address from the log rank.
Args:
address (str): The alert address to be used to send message, defaults to None.
title (str): The message title, defaults to None.
message (str): The message body, defaults to None.
"""
if address is not None and gpc.is_rank_for_log():
send_feishu_msg_with_webhook(
webhook=address,
title=title if title else get_job_key(),
message=message,
)
class MonitorTracker(Thread):
"""
Track job status and alert to Feishu during job training.
Args:
alert_address (str): The Feishu webhook address for sending alerting messages.
check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
"""
def __init__(
self,
alert_address: str,
check_interval: float = 300,
loss_spike_limit: float = 1.5,
):
super().__init__()
self.alert_address = alert_address
self.check_interval = check_interval
self.loss_spike_limit = loss_spike_limit
self.last_active_time = -1
self.last_loss_value = -1
self.stopped = False
self.start()
def run(self):
"""
start the monitor tracker.
"""
while not self.stopped:
try:
self._check_stuck()
self._check_loss_spike()
except Exception:
continue
time.sleep(self.check_interval)
def _check_stuck(self):
"""
Check training status for potential stuck condition.
"""
new_active_time = -1
if os.getenv("LAST_ACTIVE_TIMESTAMP") is not None:
new_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
if int(new_active_time) <= int(self.last_active_time) and new_active_time != -1:
self._send_alert("Training may be in stuck status, please check it.")
self.last_active_time = new_active_time
def _check_loss_spike(self):
"""
Check for loss value spikes.
"""
if gpc.is_rank_for_log():
new_loss_value = -1
new_step_id = -1
if os.getenv("LOSS") is not None:
new_loss_value = os.getenv("LOSS")
if os.getenv("STEP_ID") is not None:
new_step_id = os.getenv("STEP_ID")
if (float(new_loss_value) / float(self.last_loss_value)) > self.loss_spike_limit and new_loss_value != -1:
assert int(new_step_id) >= 0
self._send_alert(
f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
)
self.last_loss_value = new_loss_value
def _send_alert(self, message):
"""
Send alerting message to the Feishu webhook address.
Args:
message (str): The alerting message to be sent.
"""
send_alert_message(
address=self.alert_address,
message=message,
)
def stop(self):
"""
Stop the monitor tracker.
"""
self.stopped = True
class MonitorManager(metaclass=SingletonMeta):
"""
Monitor Manager for managing monitor thread and monitoring training status.
"""
def __init__(self, loss_spike_limit: float = 1.5) -> None:
self.monitor_thread = None
self.loss_spike_limit = loss_spike_limit
self.last_step_loss = -1
def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
"""Check loss value, if loss spike occurs, send alert message to Feishu."""
set_env_var(key="LOSS", value=cur_step_loss)
set_env_var(key="STEP_ID", value=step_count)
if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
send_alert_message(
address=alert_address,
message=(
f"Checking step by step: Loss spike may be happened in step {step_count}, "
f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
),
)
self.last_step_loss = cur_step_loss
def monitor_exception(self, alert_address: str = None, excp_info: str = None):
"""Catch and format exception information, send alert message to Feishu."""
filtered_trace = excp_info.split("\n")[-10:]
format_trace = ""
for line in filtered_trace:
format_trace += "\n" + line
send_alert_message(
address=alert_address,
message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
)
def handle_sigterm(self, alert_address: str = None):
"""Catch SIGTERM signal, and send alert message to Feishu."""
def sigterm_handler(sys_signal, frame):
print("receive frame: ", frame)
print("receive signal: ", sys_signal)
send_alert_message(
address=alert_address,
message=f"Process received signal {signal} and exited.",
)
signal.signal(signal.SIGTERM, sigterm_handler)
def start_monitor(
self,
job_name: str,
alert_address: str,
monitor_interval_seconds: int = 300,
loss_spike_limit: float = 1.5,
):
"""
Initialize and start monitor thread for checking training job status, loss spike and so on.
Args:
job_name (str): The training job name.
alert_address (str): The Feishu webhook address for sending alert messages.
monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
loss_spike_limit (float): The threshold ratio of the current loss to the previous loss value above which
a loss spike is reported. Defaults to 1.5.
"""
# initialize some variables for monitoring
set_env_var(key="JOB_NAME", value=job_name)
# start a monitor thread, periodically check the training status
self.monitor_thread = MonitorTracker(
alert_address=alert_address,
check_interval=monitor_interval_seconds,
loss_spike_limit=loss_spike_limit,
)
def stop_monitor(self):
"""Stop the monitor and alert thread."""
if self.monitor_thread is not None:
self.monitor_thread.stop()
monitor_manager = MonitorManager()
@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
if alert_address is not None:
try:
monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
monitor_manager.handle_sigterm(alert_address=alert_address)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
monitor_manager.stop_monitor()
else:
yield

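A minimal sketch of how the pieces above compose around a training loop. With `alert_address=None` nothing is sent and no monitor thread is started, so the sketch is safe to run standalone; with a real Feishu webhook it would also alert on loss spikes, exceptions and SIGTERM:

```python
from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager

fake_losses = [2.31, 2.28, 3.90, 2.25]  # illustrative loss curve with a spike at step 2

with initialize_monitor_manager(job_name="demo-7b", alert_address=None):
    for step, loss in enumerate(fake_losses):
        # records LOSS/STEP_ID env vars and alerts if loss > loss_spike_limit * previous loss
        monitor_manager.monitor_loss_spike(alert_address=None, step_count=step, cur_step_loss=loss)
```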
internlm/monitor/utils.py (new file)

@ -0,0 +1,32 @@
import os
from datetime import datetime
def now_time():
return datetime.now().strftime("%b%d_%H-%M-%S")
def set_env_var(key, value):
os.environ[str(key)] = str(value)
def get_job_id():
job_id = "none"
if os.getenv("SLURM_JOB_ID") is not None:
job_id = os.getenv("SLURM_JOB_ID")
elif os.getenv("K8S_WORKSPACE_ID") is not None:
job_id = os.getenv("K8S_WORKSPACE_ID")
return job_id
def get_job_name():
job_name = f"unknown-{now_time()}"
if os.getenv("JOB_NAME") is not None:
job_name = os.getenv("JOB_NAME")
return job_name
def get_job_key():
return f"{get_job_id()}_{get_job_name()}"


@ -3,15 +3,15 @@
import math
from functools import partial
from itertools import product
import amp_C
import torch
import torch.distributed as dist
from apex.multi_tensor_apply import multi_tensor_applier
from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
GradientStore,
@ -20,6 +20,7 @@ from internlm.solver.optimizer.store import (
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
ParamBcastSyncHandler,
flatten,
get_grad_accumulate_object,
has_inf_or_nan,
@ -28,33 +29,16 @@ from internlm.solver.optimizer.utils import (
split_half_float_double,
sync_param,
)
from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_model_parallel_parameter
from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)
def calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
)
return norm
def calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm**norm_type
return norm
class BaseOptimizer(Optimizer):
"""
Base Optimizer.
@ -105,12 +89,15 @@ class HybridZeroOptimizer(BaseOptimizer):
self,
optimizer: Optimizer,
cpu_offload=False,
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
param_bcast_sync_handler: ParamBcastSyncHandler = None,
):
# DynamicGradScaler related args
initial_scale = grad_scal_cfg.fp16.initial_scale
if gpc.config.model.dtype is torch.float32:
initial_scale = 1
else:
initial_scale = grad_scal_cfg.fp16.initial_scale
min_scale = grad_scal_cfg.fp16.min_scale
growth_interval = grad_scal_cfg.fp16.growth_interval
growth_factor = grad_scal_cfg.growth_factor
@ -119,9 +106,10 @@ class HybridZeroOptimizer(BaseOptimizer):
max_scale = grad_scal_cfg.max_scale
# Zero related args
overlap_communication = zero_cfg.zero_overlap_communication
reduce_bucket_size = zero_cfg.reduce_bucket_size
clip_grad_norm = zero_cfg.clip_grad_norm
self._overlap_sync_grad = zero_cfg.overlap_sync_grad
self._overlap_sync_param = zero_cfg.overlap_sync_param
super().__init__(optim=optimizer)
@ -142,7 +130,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fp32_flat_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
# gradient scaler
@ -173,7 +161,12 @@ class HybridZeroOptimizer(BaseOptimizer):
+ f"zo-{self._zero_local_rank}.pt"
)
self.params_per_rank_id_dict = []
self.overlap_broadcast = overlap_broadcast
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_sync_param:
assert self._param_bcast_sync_handler is not None
self._broadcast_comm_stream = torch.cuda.Stream()
else:
self._broadcast_comm_stream = torch.cuda.current_stream()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@ -195,6 +188,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if len(params) != 0:
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
@ -240,14 +234,16 @@ class HybridZeroOptimizer(BaseOptimizer):
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
self.skip_grad_reduce = False
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
# initialize communication stream for
# communication-computation overlapping
if self._overlap_sync_grad:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication:
if self._overlap_sync_grad:
self._attach_reduction_hook()
@property
@ -281,8 +277,10 @@ class HybridZeroOptimizer(BaseOptimizer):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
rank_to_go = numel_per_rank.index(min(numel_per_rank))
if self._overlap_sync_param:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()
@ -313,7 +311,9 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
)
# define hook
@ -334,7 +334,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_stored_in_bucket(reduce_rank)
self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
@ -352,7 +352,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
# reduce grads
self._reduce_grads_by_rank(
reduce_rank=reduce_rank,
@ -360,30 +360,27 @@ class HybridZeroOptimizer(BaseOptimizer):
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
)
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = (
f"Parameter of size ({param.size()}) has been reduced, "
+ "duplicate reduction will lead to arithmetic incorrectness"
)
raise RuntimeError(msg)
if is_param_reduced:
msg = (
f"Parameter of size ({param.size()}) has been reduced, "
+ "duplicate reduction will lead to arithmetic incorrectness"
)
raise RuntimeError(msg)
# update the flag
self._param_store.set_param_reduction_state(param, True)
# update the flag
self._param_store.set_param_reduction_state(param, True)
if self._param_store.belongs_to_current_rank(param):
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
else:
self._param_store.add_previous_reduced_param(param)
self._bucket_store.reset_by_rank(reduce_rank)
@ -401,17 +398,17 @@ class HybridZeroOptimizer(BaseOptimizer):
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
if self._overlap_sync_grad:
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
# update the reduced tensor
@ -438,6 +435,7 @@ class HybridZeroOptimizer(BaseOptimizer):
reduction_states = self._param_store.get_param_reduction_states()
for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
self._param_store.reset_reduced_data_for_compute_norm()
# accumulate gradient
avg_gradients = self._grad_store._averaged_gradients
@ -486,6 +484,30 @@ class HybridZeroOptimizer(BaseOptimizer):
# Gradients may not be fully synchronized here.
def _compute_norm_with_stage(
self,
group_id: int = 0,
last_bucket: bool = False,
last_stage: bool = False,
previous_norm=None,
):
# compute norm for gradients that have been reduced
params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
if len(params) == 0:
grads = [self.padding_grad]
params = [self.padding_tensor]
if self._clip_grad_norm > 0:
# this norm is before scaling, it will be very large
norm = compute_norm(
gradients=grads,
parameters=params,
last_stage=last_stage,
previous_norm=previous_norm,
)
return norm
def step(self, closure=None):
"""Performs a single optimization step.
@ -497,88 +519,92 @@ class HybridZeroOptimizer(BaseOptimizer):
"""
assert closure is None, "closure is not supported by step()"
timer("sync_grad").start()
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
if not self._overlap_sync_grad:
for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]:
if param.grad is not None:
self._store_and_try_reduce_grads_by_bucket(param)
# we need to reduce the gradients left in the communication bucket
self._reduce_grads_stored_in_bucket()
self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
# compute norm for gradients that were reduced in the earlier buckets
groups_norms = []
for group_id in range(self.num_param_groups):
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
if self._overlap_sync_grad:
# grads in the last bucket are reduced
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# compute norm for gradients in the last bucket
total_norms = []
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
)
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
return self._step(closure=closure)
return self._step(closure=closure, norms=total_norms)
def _step(self, closure=None):
def _step(self, closure=None, norms=None):
assert closure is None, "closure is not supported by step()"
# check for overflow
found_inf = self._check_overflow()
found_inf = False
# if there are INF values in grads, compute_norm would also return -1
# thus, we try to avoid calling _check_overflow here
# found_inf = self._check_overflow()
# Because you may encounter inf when computing norm
timer("cal_norm").start()
norm_groups = []
for group_id in range(self.num_param_groups):
# compute norm
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
parameters = self._param_store.get_fp16_params_by_rank_group(
group_id=group_id, rank=self._zero_local_rank
)
else:
# in order to prevent collective communication from hanging,
# we need to involve ranks that are not assigned parameters in compute_norm(),
# so we give them an fp16 vector of 0 values.
gradients = [self.padding_grad]
parameters = [self.padding_tensor]
if self._clip_grad_norm > 0:
# this norm is before scaling, it will be very large
norm_group = compute_norm(
gradients=gradients,
parameters=parameters,
)
if norm_group == -1:
timer("cal_norm").stop()
found_inf = True
break
norm_groups.append(norm_group)
if -1 in norms:
found_inf = True
loss_scale = float(self.loss_scale.item()) # backup
self.grad_scaler.update(found_inf)
if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None
return False, norms
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
global_norm = 0
for group_id in range(self.num_param_groups):
# compute norm
# The following operations are performed only on the rank to which parameters are assigned.
if not self.param_group_has_params[group_id]:
continue
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
# create flat gradient for the flat fp32 params
fp16_avg_grads = gradients
flat_fp16_avg_grads = flatten(fp16_avg_grads)
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
with torch.no_grad():
flat_fp16_avg_grads = flatten(gradients)
self._grad_store.reset_average_gradients_by_group(group_id)
gradients = None # release cuda memory
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
flat_fp16_avg_grads = None # release cuda memory
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
assert (
@ -588,19 +614,19 @@ class HybridZeroOptimizer(BaseOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
self._grad_store._averaged_gradients[group_id] = []
self._grad_store._averaged_gradients[group_id] = []
# unscale and clip grads
# get the global norm
global_norm_groups = []
if self._clip_grad_norm > 0:
global_norm = sum(norm_groups) ** 0.5
for norm in norms:
global_norm_groups.append(norm**0.5)
# the following operations are performed only on the rank to which parameters are assigned.
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
timer("cal_norm").stop()
# update the parameters
timer("step").start()
@ -619,35 +645,40 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
# TODO: support broadcast overlap
self.broadcast_params(overlap=False)
with torch.cuda.stream(self._broadcast_comm_stream):
self.broadcast_params()
timer("step").stop()
# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
return True, global_norm / loss_scale
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
def broadcast_params(self, overlap=False):
def broadcast_params(self):
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._zero_world_size):
# The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)
handles.append(handle)
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(ParallelMode.ZERO1),
async_op=True,
)
if not overlap:
for handle in handles:
handle.wait()
else:
return handles
if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
for handle in handles:
handle.wait()
##################
# FP16 Utilities #
@ -665,22 +696,28 @@ class HybridZeroOptimizer(BaseOptimizer):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
dist.all_reduce(
self._found_overflow,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.GLOBAL),
)
return self._found_overflow.item() > 0
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
# compute combined scale factor for this group
combined_scale = loss_scale
combined_scale_groups = []
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale = clip * loss_scale
for group_id, total_norm in enumerate(total_norm_groups):
combined_scale_groups.append(loss_scale)
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale_groups[group_id] = clip * loss_scale
for grad in grad_groups_flat:
grad.data.mul_(1.0 / combined_scale)
for group_id, grad in enumerate(grad_groups_flat):
grad.data.mul_(1.0 / combined_scale_groups[group_id])
def clip_grad_norm(self, model, max_norm):
# will conduct in the step()
@ -733,87 +770,3 @@ class HybridZeroOptimizer(BaseOptimizer):
if "zero_devide_optim_plan" in states:
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
def compute_norm(gradients, parameters, norm_type=2):
"""Get the norm
Arguments:
gradients (Iterable[Tensor]): The gradient value.
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters; apply total_norm**(1/norm_type) before using it.
"""
enable_cuda_kernels = gradients[0].device.type == "cuda"
# Norm parameters.
norm_type = float(norm_type)
# Calculate norm.
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
# Take max across all model-parallel GPUs.
if gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
for g, p in zip(gradients, parameters):
# TODO: consider the pipeline shared parameter
if (
gpc.is_initialized(ParallelMode.PIPELINE)
and hasattr(p, "pipeline_shared_module_pg")
and dist.get_rank(p.pipeline_shared_module_pg) == 0
): # if shared between different pipeline stages, only count it once
tensor_parallel_grads.append(g.data.float())
elif (
gpc.is_initialized(ParallelMode.PIPELINE)
and hasattr(p, "pipeline_shared_module_pg")
and dist.get_rank(p.pipeline_shared_module_pg) != 0
):
continue
elif (
gpc.is_initialized(ParallelMode.TENSOR)
and not is_model_parallel_parameter(p)
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
): # if not used in each chunk, such as layernorm
tensor_parallel_grads.append(g.data.float())
elif is_model_parallel_parameter(p):
tensor_parallel_grads.append(g.data.float())
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
continue
else:
raise RuntimeError("Should not arrive here")
if norm_type == 2.0 and enable_cuda_kernels:
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
else:
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
# If the norm is a plain float, convert it into a torch.Tensor.
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
# If grads are on CPU, the norm is also on CPU. Cast it to a CUDA tensor.
if not enable_cuda_kernels:
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
total_norm = tensor_parallel_norm
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL):
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
# This is because we use zero1, so we need to use this reduction.
# TODO: Check zero group to be a subset of dp group.
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
# Scale.
if total_norm == float("inf") or total_norm == -float("inf"):
total_norm = -1
return total_norm

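The clipping path now carries one norm per parameter group, and `_unscale_and_clip_grads` turns each (still loss-scaled) norm into a per-group combined scale. A standalone sketch of that arithmetic, with names local to the sketch:

```python
import torch

def unscale_and_clip(grad_groups_flat, total_norm_groups, loss_scale, clip_grad_norm):
    # mirrors HybridZeroOptimizer._unscale_and_clip_grads: the stored norm is norm * loss_scale,
    # so divide the scale out before comparing against the clip threshold
    combined_scale_groups = []
    for group_id, total_norm in enumerate(total_norm_groups):
        combined_scale_groups.append(loss_scale)
        if clip_grad_norm > 0.0:
            clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm
            if clip > 1.0:
                combined_scale_groups[group_id] = clip * loss_scale
    for group_id, grad in enumerate(grad_groups_flat):
        grad.mul_(1.0 / combined_scale_groups[group_id])
    return combined_scale_groups

grads = [torch.ones(4) * 65536, torch.ones(4) * 65536]
print(unscale_and_clip(grads, total_norm_groups=[131072.0, 32768.0],
                       loss_scale=65536.0, clip_grad_norm=1.0))
```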

@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
self._is_param_reduced = dict()
self._reduced_param = []
self._former_bucket_reduced_param = {}
self._last_bucket_reduced_param = {}
self._former_bucket_reduced_grad = {}
self._last_bucket_reduced_grad = {}
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping from parameter to rank; each parameter should be owned by exactly one rank.
@ -223,6 +228,39 @@ class ParameterStore(BaseStore):
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
group_id = getattr(param, "group_id")
if last_bucket:
if group_id not in self._last_bucket_reduced_param:
self._last_bucket_reduced_param[group_id] = []
self._last_bucket_reduced_grad[group_id] = []
self._last_bucket_reduced_param[group_id].append(param)
self._last_bucket_reduced_grad[group_id].append(param.grad)
else:
if group_id not in self._former_bucket_reduced_param:
self._former_bucket_reduced_param[group_id] = []
self._former_bucket_reduced_grad[group_id] = []
self._former_bucket_reduced_param[group_id].append(param)
self._former_bucket_reduced_grad[group_id].append(param.grad)
def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
if not last_bucket:
if group_id not in self._former_bucket_reduced_param:
return [], []
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
else:
if group_id not in self._last_bucket_reduced_param:
return [], []
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
def reset_reduced_data_for_compute_norm(self):
self._former_bucket_reduced_param = {}
self._last_bucket_reduced_param = {}
self._former_bucket_reduced_grad = {}
self._last_bucket_reduced_grad = {}
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:


@ -1,20 +1,37 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from abc import ABC, abstractmethod
from typing import Dict, Optional
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch import Tensor, nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter
logger = get_logger(__file__)
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
logger.warning("The torch implementation for cal_l2norm is slower than apex. Please note this!")
APEX_AVAILABLE = False
inf = math.inf
def flatten(input_):
return _flatten_dense_tensors(input_)
@ -46,12 +63,19 @@ def get_grad_accumulate_object(tensor):
def split_half_float_double(tensor_list):
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for _, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)
dtype_buckets = {
"torch.cuda.HalfTensor": [],
"torch.cuda.FloatTensor": [],
"torch.cuda.DoubleTensor": [],
"torch.cuda.BFloat16Tensor": [],
}
for t in tensor_list:
dtype = t.type()
if dtype in dtype_buckets:
dtype_buckets[dtype].append(t)
buckets = [bucket for bucket in dtype_buckets.values() if bucket]
return buckets
@ -150,6 +174,149 @@ def sync_param(flat_tensor, tensor_list):
p.data = q.data
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
# Convert tensor_list elements to torch.float32
tensor_list = [tensor.float() for tensor in tensor_list]
norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
if per_tensor:
per_tensor_norm = norms_tensor
else:
per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
return l2_norm, per_tensor_norm
def calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
if APEX_AVAILABLE:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False, # no per-parameter norm
)
else:
norm, _ = multi_tensor_l2norm_torch(grads, False)
return norm
def calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm**norm_type
return norm
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
"""Get the norm
Arguments:
gradients (Iterable[Tensor]): The gradient value.
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters; apply total_norm**(1/norm_type) before using.
"""
enable_cuda_kernels = gradients[0].device.type == "cuda"
# Norm parameters.
norm_type = float(norm_type)
# Calculate norm.
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
if last_stage is False:
return total_norm_cuda
if previous_norm is not None:
total_norm_cuda = max(total_norm_cuda, previous_norm)
# Take max across all model-parallel GPUs.
if gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(
total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
for g, p in zip(gradients, parameters):
# TODO: consider the pipeline shared parameter
if (
gpc.is_initialized(ParallelMode.PIPELINE)
and hasattr(p, "pipeline_shared_module_pg")
and dist.get_rank(p.pipeline_shared_module_pg) == 0
): # if shared between different pipeline stages, only count it once
tensor_parallel_grads.append(g.data.float())
elif (
gpc.is_initialized(ParallelMode.PIPELINE)
and hasattr(p, "pipeline_shared_module_pg")
and dist.get_rank(p.pipeline_shared_module_pg) != 0
):
continue
elif (
gpc.is_initialized(ParallelMode.TENSOR)
and not is_model_parallel_parameter(p)
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
): # if not used in each chunk, such as layernorm
tensor_parallel_grads.append(g.data.float())
elif is_model_parallel_parameter(p):
tensor_parallel_grads.append(g.data.float())
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
continue
else:
raise RuntimeError("Should not arrive here")
if norm_type == 2.0 and enable_cuda_kernels:
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
else:
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
# If the norm is a Python float, convert it into a torch.Tensor.
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
# If grads are on CPU, the norm is also on CPU. Cast it to a CUDA tensor.
if not enable_cuda_kernels:
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
total_norm = tensor_parallel_norm
if last_stage is False:
return total_norm
if previous_norm is not None:
total_norm = total_norm + previous_norm
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL):
dist.all_reduce(
total_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.MODEL),
)
# This is because we use zero1, so we need to use this reduction.
# TODO: Check zero group to be a subset of dp group.
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
# Scale.
if total_norm == float("inf") or total_norm == -float("inf"):
total_norm = -1
return total_norm
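As the compute_norm docstring says, the value returned above is the accumulated sum of |g|**p (reduced across ranks), not the norm itself, so callers still take the 1/p root. A small self-contained check of that convention; nothing below is repo code.
import torch
grads = [torch.randn(3, 3), torch.randn(5)]
norm_type = 2.0
# what calc_lp / the L2 branch accumulate: per-tensor norms raised to the p-th power, summed
accumulated = sum(torch.norm(g, norm_type) ** norm_type for g in grads)
grad_norm = accumulated ** (1.0 / norm_type)  # the root the caller applies afterwards
# equals the p-norm of all gradient entries flattened into a single vector
reference = torch.norm(torch.cat([g.flatten() for g in grads]), norm_type)
assert torch.allclose(grad_norm, reference)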
class BaseGradScaler(ABC):
"""A base class for the gradient scaler.
@ -313,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
self._scale = self._scale.fill_(state_dict["_scale"])
self._growth_step = state_dict["_growth_step"]
self._hysteresis_step = state_dict["_hysteresis_step"]
class ParamBcastSyncHandler:
"""
Model partition handler that overlaps parameter broadcast with the forward pass.
"""
def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
self._block_to_param = OrderedDict() # <key: nn.Module> <value: list(param)>
self._param_to_rank = dict() # <key: param> <value: rank>
self._block_to_rank = dict() # <key: nn.Module> <value: rank>
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles)>
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
total_param_num = sum(p.numel() for p in model.parameters())
avg_param_num = total_param_num * 1.0 // zero1_size
# share the same for-loop for both ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
# record which transformer/embedding/head/norm block each parameter belongs to
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _, children in _chunk.named_children():
# should be the transformer block definition in modeling_xxx.py
if isinstance(children, nn.ModuleList):
# record the block that a parameter belongs to
for _, block in enumerate(children):
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
self._block_to_param[block] = list(block.parameters())
else:
# record the block that a parameter belongs to
# self._block_to_param[name] = list(children.parameters())
self._block_to_param[children] = list(children.parameters())
alloc_num = 0
rank_to_go = 0
# process the parameters in block_to_param sequentially,
# allocate each parameter to a local rank of ParallelMode.ZERO1,
# NOTE that we do NOT consider the following scenarios:
# 1) whether a parameter is trainable;
# 2) parameters may be in different optimizer groups
for block, params in self._block_to_param.items():
# allocate a model block to a local rank of ParallelMode.ZERO1
self._block_to_rank[block] = [rank_to_go]
for p in params:
alloc_num = alloc_num + p.numel()
# in this case, allocate the param to next rank if possible
if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
rank_to_go = rank_to_go + 1
alloc_num = 0
self._block_to_rank[block].append(rank_to_go)
# allocate a parameter to a local rank of ParallelMode.ZERO1
self._param_to_rank[p] = rank_to_go
# initialize an empty list for _bcast_handles of each rank
for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
self._bcast_handles[rank] = []
# register_forward_pre_hook for transformer/embedding/norm/xxx block
self._register_sync_parameters_hook()
def _register_sync_parameters_hook(self) -> None:
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
bcast_handles = []
# gather all required broadcast handles into a list
for rank in self._block_to_rank[model]:
bcast_handles.extend(self._bcast_handles[rank])
# clear _bcast_handles for this rank since its handles are processed right below
self._bcast_handles[rank] = []
# wait for all required broadcast handles to complete
for handle in bcast_handles:
handle.wait()
# register_forward_pre_hook for transformer/embedding/norm/xxx block
for block, _ in self._block_to_rank.items():
block.register_forward_pre_hook(partial(_pre_forward_hook))
def get_rank_by_param(self, param) -> int:
return self._param_to_rank[param]
def add_bcast_handle(self, rank, handle) -> None:
self._bcast_handles[rank].append(handle)
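The handler above overlaps parameter broadcast with compute: async broadcast handles are stashed per ZERO1 rank via add_bcast_handle, and the forward pre-hook waits on them only when the owning block is about to run. A toy standalone version of that pattern; DummyHandle stands in for whatever asynchronous work object the real handler receives.
import torch
import torch.nn as nn
pending = []  # stand-in for self._bcast_handles[rank]
class DummyHandle:
    """Stands in for an asynchronous communication work object."""
    def wait(self):
        pass
def pre_forward_hook(module, inputs):
    # drain and wait only when this block is about to run, so earlier compute overlaps the comm
    while pending:
        pending.pop().wait()
block = nn.Linear(8, 8)
block.register_forward_pre_hook(pre_forward_hook)
pending.append(DummyHandle())
block(torch.randn(2, 8))  # the hook waits on the pending broadcast right before forward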

View File

@ -0,0 +1,19 @@
from .training_internlm import (
get_train_data_loader,
get_validation_data_loader,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
__all__ = [
"get_train_data_loader",
"get_validation_data_loader",
"initialize_llm_profile",
"initialize_model",
"initialize_optimizer",
"load_new_batch",
"record_current_batch_training_metrics",
]

View File

@ -0,0 +1,414 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
from functools import partial
from typing import Callable, Iterable, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
from internlm.data.dataset import get_dataset_dict
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
PackedDataset,
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.monitor import set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
logger = get_logger(__file__)
def initialize_model():
"""
Initialize model.
Returns: The neural network model to be trained or evaluated.
"""
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
if isinstance(model, nn.ModuleList):
model = nn.ModuleList(
[
NaiveAMPModel(
model=_m,
output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
for _m in model
]
)
else:
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
# This sync is very important, because the model weights kept by the optimizer are copied
# from the original parameters in memory, so we must make sure the dp sync
# does not cause the optimizer's copies to diverge from the original parameters.
sync_model_param(model, parallel_mode=ParallelMode.DATA)
# This function is needed to make sure parameters that are not split by tensor parallelism stay
# identical across tensor-parallel ranks.
sync_model_param_within_tp(model)
return model
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
Args:
model (torch.nn.Module): Your model instance to be trained or evaluated.
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
if gpc.config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(model)
else:
param_bcast_sync_handler = None
adam_cfg = gpc.config.adam
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
lr=adam_cfg.lr,
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
eps=adam_cfg.adam_eps,
)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
return optimizer, beta2_scheduler, lr_scheduler
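For reference, initialize_optimizer pulls its settings from gpc.config; the field names below follow the attribute accesses in the code above, while the values are purely illustrative.
# Illustrative values only; field names match what initialize_optimizer reads.
adam = dict(lr=1e-4, weight_decay=0.01, adam_beta1=0.9, adam_beta2=0.95, adam_eps=1e-8)
hybrid_zero_optimizer = dict(overlap_sync_param=True)  # plus the fields HybridZeroOptimizer expects as zero_cfg
grad_scaler = dict()       # passed through unchanged as grad_scal_cfg
beta2_scheduler = dict()   # keyword arguments forwarded to Beta2Scheduler
lr_scheduler = dict()      # keyword arguments forwarded to FineTuneCosineAnnealingWarmupLR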
def get_train_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
"""
Generate and return the training data loader.
Returns: A tuple of (train_dl, dataset_types).
"""
# Get the dataset types
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data
# Get the sample weight dictionary
train_folder = data_cfg.train_folder
if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
if dataset_generate_func is not None:
train_ds = dataset_generate_func()
else:
train_ds = get_packed_dataset_without_short_length(
folder=data_cfg.train_folder,
packed_length=data_cfg.packed_length,
max_length_per_sample=data_cfg.seq_len,
show_progress=dist.get_rank() == 0,
min_length=data_cfg.min_length,
min_length_dict=data_cfg.get("min_length_dict", {}),
pack_into_one_sample=data_cfg.pack_sample_into_one,
)
if dataset_generate_func is None or not train_folder:
# partition already completed
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
# Create the training dataset sampler
train_sampler = StaticBatchSampler(
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
batch_size=data_cfg.micro_num,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=1024,
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)
if dataset_generate_func is None or not train_folder:
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
# Create the training data loader
train_dl = DataLoader(
dataset=train_ds,
batch_sampler=train_sampler,
num_workers=num_worker,
pin_memory=True,
collate_fn=train_collate_fn,
persistent_workers=num_worker > 0,
)
return train_dl, dataset_types
def get_validation_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
"""Generate and return the validation data loader."""
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
else:
if dataset_generate_func is not None:
assert val_collate_fn is not None and dataloader_func is not None
val_ds = dataset_generate_func()
else:
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
if not isinstance(val_ds, dict):
val_ds = {"val": val_ds}
if val_collate_fn is None or not data_cfg.valid_folder:
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
val_dls = {}
for val_name, ds in val_ds.items():
if dataloader_func and data_cfg.valid_folder is not None:
val_dls[val_name] = dataloader_func(dataset=ds, collate_fn=val_collate_fn)
if gpc.is_rank_for_log():
logger.info(
f"load validation dataset {val_name} with valid batch size {str(data_cfg.valid_micro_num)} and "
f"{ds.size} Byte samples."
)
else:
# making the validation batch size larger can speed up evaluation, but it should not be too large,
# otherwise too much data may be dropped
batch_size = min(
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
)
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.")
continue
val_dls[val_name] = get_dpsampler_dataloader(
ds,
shuffle=False,
num_workers=num_worker,
batch_size=batch_size,
collate_fn=val_collate_fn,
drop_last=True,
) # drop_last=True, otherwise it may cause problems in the last batch
if gpc.is_rank_for_log():
logger.info(
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
f"samples {str(len(val_dls[val_name]))}."
)
return val_dls
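The batch-size arithmetic above first caps the per-rank validation batch at the smaller of valid_micro_num * micro_bsz and this rank's share of the dataset, then rounds down to a multiple of micro_bsz; a result of zero means the dataset is too small for the data-parallel world size and the split is skipped. Worked through with made-up numbers:
valid_micro_num, micro_bsz = 4, 2   # hypothetical config values
dataset_len, dp_world_size = 21, 4  # hypothetical dataset size / data-parallel world size
batch_size = min(valid_micro_num * micro_bsz, dataset_len // dp_world_size)  # min(8, 5) -> 5
batch_size = batch_size // micro_bsz * micro_bsz                             # 5 -> 4
print(batch_size)  # 4 samples per validation batch on each data-parallel rank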
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
Args:
train_dl (torch.utils.data.DataLoader): Dataloader for training.
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
train_state (TrainState): Current training state.
Returns: A batch data and the updated train_iter.
"""
timer("batch-gen").start()
try:
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
if hasattr(train_state, "batch_sampler_iter"):
next(train_state.batch_sampler_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
train_state.num_consumed_samples_in_epoch = 0
if hasattr(train_state, "batch_sampler"):
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
next(train_state.batch_sampler_iter)
timer("batch-gen").stop()
if batch[0].get("type_ids", None) is not None:
# if use_flash_attn is False, we need to unpack type_ids
if not gpc.config.model.use_flash_attn:
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
return batch, train_iter
def initialize_llm_profile(profiling: bool = False, start_time: str = None):
"""Initialize and return the profiler context manager instance."""
if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
llm_profile = torch.profiler.profile
logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
else:
llm_profile = DummyProfile
return llm_profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
+ f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
),
with_stack=True,
with_modules=True,
)
def record_current_batch_training_metrics(
get_tflops_func,
logger,
writer,
success_update,
batch_count,
batch,
train_state,
optimizer,
beta2_scheduler,
trainer,
start_time,
loss,
grad_norm,
metric,
update_panel,
):
"""
Print some training metrics of current batch.
"""
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"]
if hasattr(trainer.engine.optimizer, "grad_scaler"):
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
num_tokens_in_batch = batch[1].nelement()
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
tk_per_gpu = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
tflops = get_tflops_func((time.time() - start_time))
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
}
infos["micro_num"] = len(batch[1])
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
infos["largest_length"] = max_length_in_batch # the longest input
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
infos["smallest_batch"] = min_samples_in_batch
infos["adam_beta2"] = beta2_scheduler.get_beta2()
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
infos["fwd_bwd_time"] = fwd_bwd_time
for key, value in acc_perplex.items():
infos[key] = value
line = ""
for key, value in infos.items():
line += f"{key}={value} "
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if update_panel:
logger.info(
line,
extra={
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
},
)
else:
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
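Two bits of arithmetic above deserve a worked example: the statistics derived from cu_seqlens (each entry of batch[0]["cu_seqlens"] holds packed-sequence boundaries, so len(b) - 1 is the sample count and consecutive differences are sample lengths), and tgs, which scales the batch's token count by the data-parallel replication and divides by the global GPU count and elapsed time. Values below are made up.
import torch
cu_seqlens = [torch.tensor([0, 5, 9, 16]), torch.tensor([0, 7, 16])]  # two micro-batches
num_samples_in_batch = sum(len(b) - 1 for b in cu_seqlens)             # 3 + 2 = 5
max_length_in_batch = max((b[1:] - b[:-1]).max().item() for b in cu_seqlens)  # 9
num_tokens_in_batch, dp_world, global_world, elapsed = 32768, 8, 64, 4.0  # hypothetical
tk_per_gpu = round(num_tokens_in_batch * dp_world / global_world / elapsed, 2)  # 1024.0
print(num_samples_in_batch, max_length_in_batch, tk_per_gpu)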

View File

@ -34,18 +34,6 @@ def get_master_node():
return result
def get_process_rank():
proc_rank = -1
if os.getenv("SLURM_PROCID") is not None:
proc_rank = int(os.getenv("SLURM_PROCID"))
elif os.getenv("RANK") is not None:
# In k8s env, we use $RANK.
proc_rank = int(os.getenv("RANK"))
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
return proc_rank
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != "cuda":
norm = norm.to(torch.cuda.current_device())
@ -81,28 +69,12 @@ def move_to_device(data):
data_to_return = []
for element in data:
if isinstance(element, dict):
data_to_return.append(
{
k: (
_move_tensor(v)
if k != "inference_params"
else v._replace(attention_mask=_move_tensor(v.attention_mask))
)
for k, v in element.items()
}
)
data_to_return.append({k: _move_tensor(v) for k, v in element.items()})
else:
data_to_return.append(_move_tensor(element))
data = data_to_return
elif isinstance(data, dict):
data = {
k: (
_move_tensor(v)
if k != "inference_params"
else v._replace(attention_mask=_move_tensor(v.attention_mask))
)
for k, v in data.items()
}
data = {k: _move_tensor(v) for k, v in data.items()}
else:
raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
return data
@ -246,3 +218,21 @@ def get_megatron_flops(
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
return tflops
class DummyProfile:
"""
Dummy Profile.
"""
def __init__(self, *args, **kwargs) -> None:
pass
def __enter__(self):
return self
def __exit__(self, a, b, c):
pass
def step(self):
pass

View File

@ -0,0 +1,168 @@
from contextlib import contextmanager
import torch
import torch.distributed as dist
from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
if not gpc.is_using_pp():
prev_data_process_func = trainer.schedule.data_process_func
prev_grad_accum_size = trainer.schedule._grad_accum_size
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
prev_metric_hooks = trainer.schedule._hooks
try:
trainer.schedule.data_process_func = None
trainer.schedule._grad_accum_size = grad_accum_size
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
trainer.schedule._hooks = metric_hook_list
yield
finally:
trainer.schedule.data_process_func = prev_data_process_func
trainer.schedule._grad_accum_size = prev_grad_accum_size
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
trainer.schedule._hooks = prev_metric_hooks
@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
if gpc.is_using_pp():
pre_data_process_func = trainer.schedule.data_process_func
prev_num_microbatches = trainer.schedule.num_microbatches
prev_tensor_shape = trainer.schedule.tensor_shape
prev_metric_hooks = trainer.schedule._hooks
try:
trainer.schedule.data_process_func = None
trainer.schedule.num_microbatches = num_microbatches
trainer.schedule.tensor_shape = tensor_shape
trainer.schedule._hooks = metric_hook_list
yield
finally:
trainer.schedule.data_process_func = pre_data_process_func
trainer.schedule.num_microbatches = prev_num_microbatches
trainer.schedule.tensor_shape = prev_tensor_shape
trainer.schedule._hooks = prev_metric_hooks
@contextmanager
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
gpc.config.parallel.sequence_parallel = False
yield
finally:
gpc.config.parallel.sequence_parallel = prev_mode
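Each of these context managers follows the same save/patch/restore pattern: read the scheduler's current settings, overwrite them for evaluation, and restore them in a finally block so training state survives even if evaluation raises. A toy version of the pattern on a stand-in object (not the real trainer):
from contextlib import contextmanager
class ToySchedule:
    grad_accum_size = 4
@contextmanager
def switch_grad_accum(schedule, new_size):
    prev = schedule.grad_accum_size
    try:
        schedule.grad_accum_size = new_size
        yield
    finally:
        schedule.grad_accum_size = prev  # restored even if the body raises
sched = ToySchedule()
with switch_grad_accum(sched, 1):
    assert sched.grad_accum_size == 1
assert sched.grad_accum_size == 4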
def evaluate_on_val_dls(
trainer,
val_dls,
writer,
logger,
step_count,
update_panel: bool = False,
streaming: bool = False,
):
with switch_sequence_parallel_mode():
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for val_name, val_dl in val_dls.items():
if len(val_dl) == 0 and verbose and not streaming:
logger.info(f"Validation dataset: {val_name} is empty")
continue
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl) if not streaming else None,
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
val_res = val_metric.get_metric()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
infos = {
"step": step_count,
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
if update_panel:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
extra={
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
},
)
else:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
)
trainer.train()
torch.cuda.empty_cache()
dist.barrier()

View File

@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import logging
import os
LOGGER_NAME = "internlm"
LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
@ -11,6 +12,8 @@ LOGGER_LEVEL_HELP = (
"The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'"
)
uniscale_logger = None
def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger:
"""Configure the logger that is used for uniscale framework.
@ -24,6 +27,10 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
logger (logging.Logger): the created or modified logger.
"""
if uniscale_logger is not None:
return uniscale_logger
logger = logging.getLogger(logger_name)
if logging_level not in LOGGER_LEVEL_CHOICES:
@ -39,3 +46,53 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
logger.addHandler(handler)
return logger
def initialize_uniscale_logger(
job_name: str = None,
launch_time: str = None,
file_name: str = None,
name: str = LOGGER_NAME,
level: str = LOGGER_LEVEL,
file_path: str = None,
is_std: bool = True,
):
"""
Initialize uniscale logger.
Args:
job_name (str): The name of training job, defaults to None.
launch_time (str): The launch time of training job, defaults to None.
file_name (str): The log file name, defaults to None.
name (str): The logger name, defaults to "internlm".
level (str): The log level, defaults to "info".
file_path (str): The log file path, defaults to None.
is_std (bool): Whether to output to console, defaults to True.
Returns:
Uniscale logger instance.
"""
try:
from uniscale_monitoring import get_logger as get_uniscale_logger
except ImportError:
print("Failed to import module uniscale_monitoring. Use default python logger.")
return None
if not file_path:
assert (
job_name and launch_time and file_name
), "If file_path is None, job_name, launch_time and file_name must be setted."
log_file_name = file_name
log_folder = os.path.join(job_name, launch_time, "logs")
log_dir = os.path.join(log_folder, log_file_name)
file_path = log_dir
logger = get_uniscale_logger(name=name, level=level, filename=file_path, is_std=is_std)
if isinstance(logger, (list, tuple)):
logger = list(logger)[0]
global uniscale_logger
uniscale_logger = logger
return logger
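When file_path is not given, the log file location is assembled from job_name, launch_time and file_name exactly as above; a quick sketch of the resulting layout (all values hypothetical, the file name simply mimics get_parallel_log_file_name output):
import os
job_name, launch_time = "my_train_job", "2023-08-01-12-00-00"
file_name = "main_dp=0_tp=0_pp=0.log"
file_path = os.path.join(job_name, launch_time, "logs", file_name)
print(file_path)  # my_train_job/2023-08-01-12-00-00/logs/main_dp=0_tp=0_pp=0.log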

View File

@ -14,18 +14,19 @@ class _Timer:
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
self.stream = torch.cuda.current_stream()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.stream.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.stream.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False

View File

@ -2,8 +2,11 @@
# -*- encoding: utf-8 -*-
import copy
import fcntl
import os
import socket
import time
from enum import Enum
from typing import Dict
import torch
@ -11,15 +14,26 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
llm_load,
llm_save,
)
logger = get_logger(__file__)
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
SNAPSHOT_CHECKPOINT = 2
def get_model_topology(model):
"""
Returns:
@ -138,11 +152,13 @@ def save_optimizer_checkpoint(optim, state_path):
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
states = optim.state_dict()
if isinstance(optim, HybridZeroOptimizer):
if gpc.get_global_rank() < optim.zero_world_size:
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
llm_save(os.path.join(state_path, fp), states)
if "zero_devide_optim_plan" in states:
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
@ -152,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path):
llm_save(os.path.join(state_path, fp), states)
def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
"""
Save checkpoint to the given folder path.
"""
start = time.time()
torch.distributed.barrier()
folder = os.path.join(folder, str(train_state.step_count))
logger.info(
f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
)
timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()
timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
sampler_state = train_state.batch_sampler.state_dict()
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
if model_config is not None:
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
torch.distributed.barrier()
if gpc.is_rank_for_log():
timer.log(["save-model", "save-optimizer"], logger=logger)
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
def load_optimizer_checkpoint(folder, optim):
"""Load the optimizer state from the local file system or remote
object storage Service (OSS).
@ -287,3 +265,369 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
if gpc.is_rank_for_log():
logger.info(f"reload load_scheduler:{lr_scheduler}")
class CheckpointManager:
"""StorageManagerContext"""
def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
"""
CheckpointManager is used to decide when to store a ckpt. If asynchronous
upload mode is used, you must call wait_async_upload_finish at the end of the program to wait
for the asynchronous ckpt upload to complete.
Args:
ckpt_config (dict): model checkpoint config.
model (nn.module): model obj
optimizer (object): optimizer obj.
lr_scheduler (object): lr_scheduler obj.
model_config (dict): model config.
"""
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
self.checkpoint_every = ckpt_config.checkpoint_every
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.stop_file_path = ckpt_config.stop_file_path
self.load_model_only_folder = ckpt_config.load_model_only_folder
self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.load_optimizer = gpc.config.ckpt.load_optimizer
self.model = model
self.model_config = model_config
self.model_config_file = model_config_file
if self.stop_file_path and gpc.get_global_rank() == 0:
dir_path = os.path.dirname(self.stop_file_path)
if dir_path != "" and not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(self.stop_file_path, "w", encoding="utf-8") as f:
f.write("0")
if ckpt_config.load_given_ckpt is False:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
latest_ckpt_path = self.query_lastest_ckpt()
if latest_ckpt_path:
self.load_ckpt_folder = latest_ckpt_path
else:
# At this time, we have to load model init weights and train from step 0.
self.load_ckpt_folder = self.load_model_only_folder
else:
self.load_ckpt_folder = ckpt_config.load_ckpt_folder
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
if self.stop_file_path is None:
logger.warning("no set stop_file_path, quit_signal_handler is disable")
def quit_signal_handler(self, train_state) -> bool:
"""
Exit-signal detection function: if we write an exit step into the 'QUIT_FILE_PATH' file,
all ranks will save a ckpt and, depending on the sign, exit.
A negative integer step means save ckpt only.
A positive integer step means save ckpt and quit.
Args:
train_state (TrainState):
Returns:
bool: whether to quit.
"""
now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
if self.stop_file_path is None:
return now_break, now_save_ckpt, save_type
with open(self.stop_file_path, "a+", encoding="utf-8") as f:
fcntl.flock(f, fcntl.LOCK_EX)
f.seek(0)
msg = f.read()
fcntl.flock(f, fcntl.LOCK_UN)
action_step = int(msg)
if action_step < 0 and abs(action_step) == train_state.step_count:
now_save_ckpt = True
if action_step > 0 and action_step == train_state.step_count:
now_break, now_save_ckpt = True, True
if action_step != 0 and gpc.is_rank_for_log():
msg = "Stop" if action_step > 0 else "Save"
action_step = abs(action_step)
if train_state.step_count <= action_step:
if self.feishu_address:
send_alert_message(
address=self.feishu_address,
message=f"training will {msg} at step_count {action_step}!\
now step_count is {train_state.step_count}",
)
return now_break, now_save_ckpt, save_type
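In practice the quit signal is driven by writing a single integer into the stop file: a negative value asks every rank to save a ckpt once step_count reaches abs(step), a positive value asks them to save and then exit at that step, and 0 is a no-op. A hypothetical operator-side example (the path is made up and must match ckpt_config.stop_file_path):
stop_file_path = "/path/to/llm_quit_file"  # hypothetical; must equal ckpt_config.stop_file_path
with open(stop_file_path, "w", encoding="utf-8") as f:
    f.write("1000")   # save a ckpt and exit at step 1000; "-1000" would save without exiting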
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return False
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
if save_ckpts is False:
save_ckpts = singal_save_ckpts
save_type = singal_save_type
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
self.storage_manager.wait()
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
# Snapshot number, with only two snapshots written alternately.
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
else:
save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))
self.save_checkpoint(
folder=save_ckpt_folder,
model=self.model,
optimizer=self.optimizer,
scheduler=self.lr_scheduler,
train_state=train_state,
model_config=self.model_config,
model_config_file=self.model_config_file,
)
return now_break
def wait_async_upload_finish(self):
"""wait for all checkpoint uploads to be completed"""
self.storage_manager.wait()
torch.distributed.barrier()
def query_latest_snapshot_step_boto3(self):
"""query_latest_snapshot_step_boto3
Returns:
Tuple(str, int): path of the latest ckpt and its step; if not found, (None, None) is returned.
"""
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
if len(ckpt_list) == 0:
return None, None
max_normal_step = 0
ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
ckpt_list.sort(reverse=True)
for ckpt in ckpt_list:
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
for fn in fns_list:
if fn.endswith(".step"):
max_normal_step = ckpt
break
if max_normal_step != 0:
break
max_normal_step = ckpt_list[0]
load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
max_step_0, max_step_1 = 0, 0
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
snap_step = max(max_step_0, max_step_1)
load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
load_step = max(snap_step, max_normal_step)
return load_path, load_step
def query_latest_snapshot_step_local(self):
max_step, max_step_path = 0, None
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
for fn in files:
fn = fn.strip("/")
if fn.endswith(".step"):
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
# as an integrity flag.
step = int(fn.rsplit(".", maxsplit=1)[0])
if max_step < step:
max_step = step
max_step_path = root
return max_step_path, max_step
def query_lastest_ckpt(self):
latest_checkpoint = None
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
if self.save_ckpt_folder.startswith("boto3"):
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
elif self.save_ckpt_folder.startswith("local"):
latest_checkpoint, step = self.query_latest_snapshot_step_local()
else:
latest_checkpoint, step = None, 0
if latest_checkpoint is not None:
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
send_alert_message(
address=self.feishu_address,
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
)
else:
if gpc.is_rank_for_log():
send_alert_message(
address=self.feishu_address,
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
)
return latest_checkpoint
def try_load_model(self, current_time=""):
model_load_path = None
if self.load_ckpt_folder and self.load_model_only_folder:
raise ValueError(
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
)
if self.load_ckpt_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_ckpt_folder
elif self.load_model_only_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_model_only_folder
else:
if gpc.is_rank_for_log():
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
# Loading model weights must be done before zero is initialized.
if model_load_path is not None:
load_model_checkpoint(folder=model_load_path, model=self.model)
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
"""Attempt to restore the training state of the last ckpt.
Args:
lr_scheduler (_LRScheduler): lr_scheduler object.
optimizer (Optimizer): optimizer object.
lr (float): learning rate.
train_state (dict): training states.
train_dl (DataLoader): training dataloader object.
"""
if self.load_ckpt_folder is not None:
# load optimizer states.
if self.load_optimizer:
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
# load lr scheduler states.
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
# load training states.
load_context(self.load_ckpt_folder, train_dl, train_state)
# load dataloader sampler states.
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
if hasattr(train_state, "data_state_dict"):
train_dl.dataset.load_state_dict(
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
def save_checkpoint(
self,
folder,
model,
optimizer,
scheduler,
train_state: TrainState,
model_config: Dict = None,
model_config_file: str = None,
):
"""
Save checkpoint to the given folder path.
"""
start = time.time()
self.set_save_folder(folder, train_state.step_count)
torch.cuda.synchronize()
torch.distributed.barrier()
if gpc.is_rank_for_log():
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()
timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
if (
hasattr(train_state, "data_state_dict")
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
and gpc.get_local_rank(ParallelMode.PIPELINE) == 0
):
llm_save(
os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"),
saved_obj=train_state.data_state_dict,
)
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
sampler_state = train_state.batch_sampler.state_dict()
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
if model_config is not None:
# Model configuration dictionary.
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
if model_config_file is not None:
# The complete training config file content, stored in binary format.
llm_save(os.path.join(folder, "config_file.pt"), saved_obj=model_config_file)
torch.distributed.barrier()
if gpc.is_rank_for_log():
timer.log(["save-model", "save-optimizer"], logger=logger)
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
if self.storage_manager.async_mode is False:
llm_save(
os.path.join(folder, f"{train_state.step_count}.step"),
saved_obj=dict({"step": train_state.step_count}),
)
def set_save_folder(self, folder, step):
self.storage_manager.latest_save_folder = folder
self.storage_manager.latest_save_step = step

View File

@ -46,3 +46,16 @@ def sync_model_param_within_tp(model):
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
def get_parallel_log_file_name():
if gpc.is_rank_for_log():
fn_prefix = "main_" # Indicates a rank with more output information
else:
fn_prefix = ""
log_file_name = (
f"{fn_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
)
return log_file_name

View File

@ -22,9 +22,9 @@ class Registry:
"""Registers a module represented in `module_class`.
Args:
module_class (class): The module to be registered.
module_name (str): The name of module to be registered.
Returns:
class: The module to be registered, so as to use it normally if via importing.
function: The module to be registered, so that it can be used normally via importing.
Raises:
AssertionError: Raises an AssertionError if the module has already been registered before.
"""

View File

@ -1,15 +1,13 @@
import os
import time
from collections import OrderedDict
from functools import partial
from functools import partial, reduce
from typing import Any, Dict, List, Tuple
import pyecharts
import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.solver.pipeline_utils import partition_uniform
from internlm.core.naive_amp import NaiveAMPModel
mb = 1024 * 1024
@ -107,6 +105,8 @@ class SimpleMemState:
"""
Update the total memory usage of the model and sub-models.
"""
self._total_mem = self._layer_mem
for stat in self.sub_model_stats.values():
# Update sub-model status first.
stat.update_total_memory()
@ -169,6 +169,39 @@ class SimpleMemState:
return {"name": self.layer_name, "children": children}
class ActivationMemState:
"""
Activation Memory State
"""
def __init__(self, num_chunks: int) -> None:
self._num_chunks = num_chunks
self.inited: List[bool] = [False for _ in range(num_chunks)]
self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]
@property
def total_mem(self) -> int:
return sum(state.total_mem for state in self.states)
def dump(self, prefix: str = "") -> str:
return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])
def to_json(self, base: int = 1024 * 1024) -> List:
return [state.to_json(base) for state in self.states]
def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
if num_chunks > 1:
model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
else:
model = model.model if isinstance(model, NaiveAMPModel) else model
return model, num_chunks
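The profiler below estimates activation memory by registering forward hooks on every leaf sub-module of a chunk and summing the byte size of their outputs. A self-contained toy version of that hook pattern on a throwaway model (not the real profiler):
from functools import partial
import torch
import torch.nn as nn
activation_bytes = {}
def leaf_hook(name, module, inputs, output):
    if isinstance(output, torch.Tensor):
        activation_bytes[name] = output.element_size() * output.nelement()
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, sub in model.named_modules():
    if len(sub._modules) == 0:  # leaf modules only, mirroring the check used below
        sub.register_forward_hook(partial(leaf_hook, name))
model(torch.randn(3, 4))
print(sum(activation_bytes.values()), "bytes of traced activations")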
class SimpleMemoryProfiler:
"""
A memory profiler for an LLM model.
@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
model (torch.nn.Module): The model to profile.
optimizer (torch.optim.Optimizer): The optimizer used for training the model.
log_file (str): The file to write the memory state information to.
activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
total_steps: number of steps to trace.
"""
def __init__(
@ -186,9 +219,8 @@ class SimpleMemoryProfiler:
optimizer: torch.optim.Optimizer,
log_folder: str,
total_steps: int = 5,
activation_config: List[str] = None,
):
self._model = model
self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
self._optimizer = optimizer
self._log_folder = log_folder
self._remaining_steps = total_steps
@ -197,17 +229,20 @@ class SimpleMemoryProfiler:
self._record_start_time = time.time()
# For activation memory state.
self._activation_config = activation_config
self._activation_mem_inited: bool = False
self._activation_mem: int = 0
self._activation_max_count = 0
self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
self._activation_mem_max: int = 0
self._activation_base_mems = ActivationMemState(self._num_model_chunks)
# Check or create log folder
os.makedirs(self._log_folder, exist_ok=True)
# Register activation memory tracking hooks
self._register_activation_trace_hooks()
if self._num_model_chunks > 1:
for chunk_id in range(self._num_model_chunks):
self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
else:
self._register_activation_trace_hooks(0, self._model)
# Calculate static parameter cuda memory
self._param_mem_state = SimpleMemState("param_mem")
@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))
# Generate the first memory record
self.point(create=True)
self.point(with_options="params,grads,os_params", create=True)
def point(self, with_options: str = "", create: bool = False) -> None:
"""
@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
if "os_state" in options:
layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
if "activation_base" in options:
layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()
# Write memory state information to log file
file_mode = "w" if create else "a"
@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
[self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
"os_memory_sunburst",
)
self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
# Generate summary sunburst chart
summary_sunburst_data = [
{"name": "params", "value": self._param_mem_state.total_mem // mb},
{"name": "grads", "value": self._grad_mem_state.total_mem // mb},
{"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
{"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
{"name": "activation", "value": self._activation_base_mem.total_mem // mb},
{"name": "activation", "value": self._activation_mem_max // mb},
]
self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
{},
{
"r0": "10%",
"r": "40%",
"r": "35%",
"itemStyle": {"borderWidth": 3},
"label": {"align": "left"},
},
{"r0": "40%", "r": "65%", "label": {"align": "left"}},
{"r0": "65%", "r": "80%", "label": {"align": "left"}},
{"r0": "35%", "r": "55%", "label": {"align": "left"}},
{"r0": "55%", "r": "70%", "label": {"align": "left"}},
{"r0": "70%", "r": "80%", "label": {"align": "left"}},
{"r0": "80%", "r": "90%", "label": {"align": "left"}},
{
"r0": "90%",
@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
f"{self._log_folder}/{name}.html"
)
def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
def _inner_activation_trace_hook(
self,
chunk_id: int,
layer_name: str,
model: Any,
inputs: Any,
output: torch.Tensor,
) -> None:
"""
Hook function to trace the activation memory usage for an inner layer.
@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
del model, inputs
assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"
if self._stoped or self._activation_mem_inited:
if self._stoped or self._activation_base_mems.inited[chunk_id]:
return
# Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
self._activation_base_mems.states[chunk_id].add(
layer_name, output.element_size() * output.nelement(), flush=False
)
def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
"""
Hook function to trace the activation memory usage for a forward pass.
@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
return
# Check if the activation memory has been initialized
if self._activation_mem_inited is False:
if self._activation_base_mems.inited[chunk_id] is False:
self._activation_base_mems.inited[chunk_id] = True
# Update the total memory of the activation base memory state
self._activation_base_mem.update_total_memory()
self._activation_base_mems.states[chunk_id].update_total_memory()
# Set with_options to "activation_base" to include activation_base_layout in the memory dump
self._activation_mem_inited = True
with_options = "activation_base"
else:
with_options = ""
# Accumulate activation memory usage for each forward pass
self._activation_mem += self._activation_base_mem.total_mem
# Update activation max count
if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
if self._activation_mem > self._activation_mem_max:
self._activation_mem_max = self._activation_mem
# Trigger a memory record
self.point()
self.point(with_options)
def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
"""
Hook function to trace the activation memory usage for a backward pass.
@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
return
# Release activation memory usage for each backward pass
self._activation_mem -= self._activation_base_mem.total_mem
self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem
# Trigger a memory record
self.point()
def _register_activation_trace_hooks(self) -> None:
def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
"""
Register activation trace hooks for the model and each submodule in the model.
"""
# Register inner activation trace hooks for each submodule in the model
for layer_name in self._activation_config:
# Register a hook for every activation
model = self._model
sub_models = layer_name.split(".")
# Get the target sub-model
for sub_model_name in sub_models:
try:
model = model.get_submodule(sub_model_name)
except AttributeError:
model = None
break
for layer_name, sub_model in model_chunk.named_modules():
# Register the hook
if model is not None:
model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
if len(sub_model._modules) != 0:
continue  # TODO: in some special cases, we may need some additional configuration to correct this
sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))
# Register a forward hook for the main model to track activation memory usage
self._model.register_forward_hook(self._activation_trace_hook_forward)
model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
# Register a backward hook for the main model to release activation memory usage
self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))
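A minimal, self-contained sketch of the `functools.partial` currying used in the registration above: pinning `(chunk_id, layer_name)` up front still leaves the standard PyTorch forward-hook signature of `(module, inputs, output)`. The layer name and shapes below are illustrative only.

```python
from functools import partial

import torch


def traced_hook(chunk_id, layer_name, module, inputs, output):
    # The real profiler accumulates output.element_size() * output.nelement() here.
    size = output.element_size() * output.nelement()
    print(f"chunk={chunk_id} layer={layer_name} bytes={size}")


layer = torch.nn.Linear(4, 4)
# partial() pins the first two arguments; PyTorch supplies (module, inputs, output).
layer.register_forward_hook(partial(traced_hook, 0, "blocks.0.mlp.w1"))
_ = layer(torch.randn(2, 4))
```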
def _calc_tensor_memory(
self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
self._calc_tensor_memory(root_stat, named_tensors)
def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
# TODO: support interleaved pipeline scheduling.
assert num_chunks == 1, "Only support num_chunks == 1"
if gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
else:
pipeline_size = 1
pipeline_rank = 0
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
parts = all_parts[pipeline_rank]
start, end = parts[0]
num_blocks = end - start
block_conf_tmpl = [
"mixer.rotary_emb",
"mixer.Wqkv",
"mixer.inner_attn",
"mixer.inner_cross_attn",
"mixer.out_proj",
# "dropout1", # skip when dropout_selective_checkpoint is True
# "dropout2", # skip when dropout_selective_checkpoint is True
"norm1",
"norm2",
"mlp.w1",
"mlp.w2",
"mlp.w3",
]
block_conf = []
for block_id in range(num_blocks):
block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]
# We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
# If they don't exist, they will be automatically ignored when registering activation trace hooks.
activation_conf = ["embedding", "norm", "head"] + block_conf
return activation_conf
if __name__ == "__main__":
class SimpleModel(torch.nn.Module):
@ -635,32 +630,39 @@ if __name__ == "__main__":
return output
def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
if _num_chunks > 1:
_output = _input
for _model_chunk in _model_chunks:
_output = _model_chunk(_output)
else:
_output = _model_chunks(_input)
return _output
# num_chunks config
_num_chunks = 1
# init model and optimizer
_model: torch.nn.Module = SimpleModel()
if _num_chunks > 1:
_chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
_model = torch.nn.ModuleList(_chunks).cuda()
else:
_model: torch.nn.Module = SimpleModel().cuda()
_optimizer = torch.optim.Adam(_model.parameters())
# create activation config for simple model layer by layer.
activation_configs = [
# model level 0
"layer1",
"layer2",
"layer3",
# model level 1
"layer2.layer1",
"layer2.layer3",
]
_model.modules()
# init profiler
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)
_optimizer.zero_grad()
x1 = torch.randn((128, 5120))
x2 = torch.randn((128, 5120))
out1 = _model(x1)
out2 = _model(x2)
# inputs
x1 = torch.randn((128, 5120)).cuda()
x2 = torch.randn((128, 5120)).cuda()
# forward
out1 = _simple_schedule(_num_chunks, _model, x1)
out2 = _simple_schedule(_num_chunks, _model, x2)
# backward
out1.mean().backward()
out2.mean().backward()


@ -1,21 +1,34 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import asyncio
import concurrent.futures
import hashlib
import io
import os
import pickle
import re
import socket
from enum import Enum
from typing import Any, Dict, List, Union
import stat
from asyncio import InvalidStateError
from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Union
import boto3
import botocore
import torch
import torch.distributed as dist
from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
try:
import boto3
import botocore
except ImportError:
pass
logger = get_logger(__file__)
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
@ -41,10 +54,6 @@ def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
class StorageClient:
"""
StorageClient as a client for s3 storage access.
@ -54,7 +63,7 @@ class StorageClient:
self.handler = handler
@staticmethod
def load(client, load_path: str, map_location):
def load(client, load_path: str, *args, **kwargs):
raise NotImplementedError
@staticmethod
@ -71,25 +80,51 @@ class StorageClient:
class Boto3MetaInfo:
def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
self.client = client
"""Boto3 meta info for save/load etc."""
def __init__(
self,
is_async,
handler: StorageClient,
bucket_name: str,
endpoint: str,
file_path: str,
async_upload_fn: callable,
local_nvme_path=None,
) -> None:
self.is_async = is_async
self.client = handler
self.bucket_name = bucket_name
self.endpoint = endpoint
self.file_path = file_path
self.async_upload_fn = async_upload_fn
self.local_nvme_path = local_nvme_path
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"
class LocalMetaInfo:
def __init__(self, client: StorageClient, dest_path: str) -> None:
self.client = client
"""Local meta info for save/load etc."""
def __init__(self, handler: StorageClient, dest_path: str) -> None:
self.is_async = False
self.client = handler
self.dest_path = dest_path
self.async_upload_fn = None
def unpack_meta(meta):
args = []
is_async = meta.is_async
for k, v in meta.__dict__.items():
if k == "endpoint":
if k in ("endpoint", "async_upload_fn", "is_async"):
continue
if not is_async and k in ("local_nvme_path",):
continue
args.append(v)
return args
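`unpack_meta` relies purely on the attribute assignment order of the meta classes; a hedged illustration with a stand-in object (all values below are made up) shows how the surviving attributes line up with the positional parameters of the client methods.

```python
# Stand-in for Boto3MetaInfo; attribute order mirrors its __init__ above.
class FakeBoto3Meta:
    def __init__(self):
        self.is_async = True
        self.client = "<boto3 handler>"
        self.bucket_name = "model_weights"
        self.endpoint = "http://10.0.0.1:80"
        self.file_path = "llm_ckpts/20/model_tp0_pp0.pt"
        self.async_upload_fn = "<async upload fn>"
        self.local_nvme_path = "/dev/shm/test/model_tp0_pp0.pt-Aug07_10-15-42-12345.tmpfile"


meta = FakeBoto3Meta()
skip_always = ("endpoint", "async_upload_fn", "is_async")
args = [v for k, v in meta.__dict__.items()
        if k not in skip_always and (meta.is_async or k != "local_nvme_path")]
# args is now the (handler, bucket_name, fp, local_nvme_path) positional prefix
# expected by Boto3Client.load and Boto3Client.sync_upload_fileobj.
print(args)
```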
@ -101,21 +136,6 @@ def compute_file_md5_by_chunk(file_name: str):
return hash_md5.hexdigest()
def get_boto3_meta(fp: str) -> Boto3MetaInfo:
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
parts = fp.lstrip("s3://").split(os.path.sep)
match = boto3_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
class Boto3Client(StorageClient):
"""
Boto3Client
@ -169,7 +189,9 @@ class Boto3Client(StorageClient):
)
@staticmethod
def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
def sync_upload_fileobj(
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
): # pylint: disable=W0613
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
@ -182,7 +204,14 @@ class Boto3Client(StorageClient):
) from exc
@staticmethod
def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
def load(
handler,
bucket_name: str,
fp: str,
local_nvme_path: str, # pylint: disable=W0613
*args,
**kwargs,
) -> Dict:
"""
Args:
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@ -191,7 +220,7 @@ class Boto3Client(StorageClient):
with io.BytesIO() as f:
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
f.seek(0)
states = torch.load(f, *args, map_location=map_location, **kwargs)
states = torch.load(f, *args, **kwargs)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@ -199,28 +228,40 @@ class Boto3Client(StorageClient):
return states
@staticmethod
def assert_fp_exists(
handler,
bucket_name: str,
fp: str,
):
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def get_fns(handler, bucket_name: str, fp: str):
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
"""
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
for obj in page["Contents"]:
fp: str = obj["Key"]
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
return folder_name_list
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
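The key-to-child-name extraction in `get_fns` above is easiest to see on concrete object keys; a hedged example with made-up keys:

```python
# Strip the queried prefix from each key, then keep only the first remaining path
# component, so nested objects collapse to their top-level checkpoint folder name.
fp = "llm_ckpts"
keys = [
    "llm_ckpts/20/model_tp0_pp0.pt",
    "llm_ckpts/20/optimizer_tp0_pp0_zo0.pt",
    "llm_ckpts/40/model_tp0_pp0.pt",
]

children = {key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0] for key in keys}
print(sorted(children))  # -> ['20', '40']
```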
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
try:
with open(local_nvme_path, "rb") as f:
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
) from exc
except Exception as e:
raise e
@staticmethod
def delete_obj(handler, fp: str):
raise NotImplementedError("boto3 not support delete_obj")
class LocalClient(StorageClient):
@ -241,11 +282,11 @@ class LocalClient(StorageClient):
torch.save(saved_obj, fp, *args, **kwargs)
@staticmethod
def load(handler, fp: str, *args, map_location="cpu", **kwargs):
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
assert isinstance(handler, LocalClient)
assert os.path.exists(fp), f"{fp} is not found!"
with open(fp, "rb") as f:
states = torch.load(f, map_location=map_location, *args, **kwargs)
states = torch.load(f, *args, **kwargs)
return states
@staticmethod
@ -267,9 +308,77 @@ class LocalClient(StorageClient):
os.remove(fp)
def get_tmp_file_name(tmp_local_folder: str, fp: str):
"""
It should be noted that all our temporary files will be stored in the same folder,
so the file name passed upstream must be unique.
"""
base_path = os.path.join(tmp_local_folder, fp.split("/")[-1])
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
pid = os.getpid()
# step = self.step_counter
return "-".join([base_path, current_time, str(pid)]) + ".tmpfile" # , str(step)
def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
parts = fp.lstrip("s3://").split(os.path.sep)
match = boto3_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
return Boto3MetaInfo(
is_async=is_async,
handler=None,
bucket_name=bucket_name,
endpoint=endpoint,
file_path=os.path.sep.join(parts[1:]),
async_upload_fn=Boto3Client.async_upload_fileobj,
local_nvme_path=tmp_step_file,
)
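A hedged walk-through of the URL parsing above with a made-up path: the bucket name and endpoint IP come out of the first path component, and everything after it becomes the object key.

```python
import os
import re

boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")

fp = "s3://model_weights.10.0.0.1/llm_ckpts/20/model_tp0_pp0.pt"  # made-up path
parts = fp.lstrip("s3://").split(os.path.sep)
match = boto3_url_re.match(parts[0])
bucket_name, endpoint = match.group(1), "http://" + match.group(2) + ":80"
print(bucket_name, endpoint, os.path.sep.join(parts[1:]))
# -> model_weights http://10.0.0.1:80 llm_ckpts/20/model_tp0_pp0.pt
```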
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
def get_mount_point_free_size(path: str):
"""
Returns the remaining free space of the temporary storage mount point.
Args:
path (str): temporary storage folder path.
Raises:
FileNotFoundError: If the temporary storage folder does not exist,
an error will be reported
"""
if os.path.exists(path):
st = os.statvfs(path)
# f_bavail: Number of free blocks for unprivileged users.
# f_bsize: Filesystem block size.
# return unit is TB.
return st.f_bavail * st.f_bsize / (1024**3)
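A small worked example of the `statvfs` arithmetic above (the path is just an example): free bytes are available blocks times block size, which can then be scaled to whichever unit is reported.

```python
import os

st = os.statvfs("/tmp")
free_bytes = st.f_bavail * st.f_bsize
print(f"free on /tmp: {free_bytes / 1024**3:.1f} GB ({free_bytes / 1024**4:.3f} TB)")
```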
def check_tmp_folder_accessibility(tmp_local_folder: str):
"""
Check access permissions for temporary storage.
"""
ret = True
if os.path.exists(tmp_local_folder):
ret &= os.access(tmp_local_folder, os.W_OK)
ret &= os.access(tmp_local_folder, os.R_OK)
if ret is False:
error_str = f"{socket.gethostname()} does not have read and write permissions on {tmp_local_folder}"
raise RuntimeError(error_str)
class StorageManager(metaclass=SingletonMeta):
"""
Storage Manager for saving or loading checkpoint.
TODO: add a thread to poll the asynchronous storage state.
"""
BACKEND_TYPE = {"boto3", "local"}
@ -279,8 +388,44 @@ class StorageManager(metaclass=SingletonMeta):
}
CLI_DICT = {}
def __init__(self) -> None:
pass
def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
self._exception_list = []
self._to_be_del_files = []
self._async_stack = []
self.upload_count = 0
self.tmp_local_folder = tmp_local_folder
self.async_mode = async_mode
self.has_warning = False
self._async_loop = None
self._thread_pool = None
self.latest_save_folder = None
self.latest_save_step = 0
self.async_task_peeding = False
if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))
# Try to create tmp folder
try:
os.makedirs(self.tmp_local_folder, exist_ok=True)
os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
except FileExistsError:
pass
# In case it is a directory created by other users, we check the permissions again.
check_tmp_folder_accessibility(self.tmp_local_folder)
# Try to clean tmp folder's empty folder.
self.try_delete_tmpfile(self.tmp_local_folder)
# Available storage space check.
free_size = get_mount_point_free_size(self.tmp_local_folder)
if free_size < 0.1:
logger.error(f'tmp_local_folder only has "{free_size}" GB free space, less than 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
"""
@ -301,7 +446,7 @@ class StorageManager(metaclass=SingletonMeta):
meta_info = get_local_meta(path)
backend_key = backend
elif backend == "boto3":
meta_info = get_boto3_meta(path)
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (meta_info.endpoint,)
if (
@ -310,10 +455,12 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
raise RuntimeWarning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
)
if not self.has_warning:
logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
)
self.has_warning = True
assert backend in StorageManager.BACKEND_TYPE, f"Unknown backend: {backend}"
@ -333,19 +480,145 @@ the proxy may make boto3 unavailable or affect performance."
meta = self._get_client(path=folder)
return meta.client.get_fns(*unpack_meta(meta))
def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
meta = self._get_client(path=save_path)
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
if async_upload is None:
async_upload = self.async_mode
if async_upload:
assert (
self.tmp_local_folder
), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
tmp_step_file = meta.local_nvme_path
self._to_be_del_files.append(tmp_step_file)
with open(tmp_step_file, "wb") as f:
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
self.async_task_peeding = True
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
self.upload_count += 1
def load(self, load_path: str, *args, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
def delete_obj(self, fp: str):
meta = self._get_client(path=fp)
meta.client.delete_obj(*unpack_meta(meta))
def _del_tmp_folder(self):
for fp in self._to_be_del_files:
try:
os.remove(fp)
except FileNotFoundError:
pass
except SystemError as e:
logger.error(f'failed to delete file {fp}, reason: "{e}"')
else:
pass
storage_manager = StorageManager()
def try_delete_tmpfile(self, tmp_dir: str):
"""Delete temporary files in tmp_dir."""
for filename in os.listdir(tmp_dir):
if filename.endswith(".tmpfile"):
file_path = os.path.join(tmp_dir, filename)
try:
os.remove(file_path)
logger.info(f"Delete tmpfile: {file_path}")
except OSError:
# Ignore deletion errors
pass
async def _sync_tasks(self) -> Awaitable[None]:
if self._async_stack:
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
count = 0
while self._async_stack:
t = self._async_stack[0]
try:
e = t.exception()
if e:
self._exception_list.append((e, count))
logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")
# raise e
count += 1
self._async_stack.pop(0)
except InvalidStateError:
# Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
pass
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
Overview:
Execute the task in the background, then append the future instance to _async_stack.
Arguments:
- fn (:obj:`Callable`): Synchronous function to execute in the thread pool.
"""
if not self._async_loop:
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
self._async_stack.append(t)
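The executor pattern above (an asyncio loop pushing blocking uploads onto a thread pool, stacking the futures for a later drain) can be reproduced in isolation; a hedged, stand-alone sketch with a fake upload function:

```python
import asyncio
import concurrent.futures
import time

loop = asyncio.new_event_loop()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
stack = []


def fake_upload(name):
    time.sleep(0.1)  # stands in for uploading a tmp nvme file via boto3
    print(f"uploaded {name}")


for fname in ("model_tp0_pp0.pt.tmpfile", "optimizer_tp0_pp0_zo0.pt.tmpfile"):
    stack.append(loop.run_in_executor(pool, fake_upload, fname))

# Equivalent of StorageManager.wait(): drain all pending futures before moving on.
loop.run_until_complete(asyncio.wait(stack))
```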
def wait(self) -> bool:
"""Wait for async operations to complete."""
if not self.async_mode:
return
if not self.async_task_peeding:
return
if self._async_loop:
self._async_loop.run_until_complete(self._sync_tasks())
if self._exception_list:
for error_msg, file_id in self._exception_list:
logger.error(
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
f"failed on step {self.upload_count}: {error_msg}"
)
# TODO: Re-upload in sync mode
raise RuntimeError(
f"Failed to upload {self._to_be_del_files[file_id]} " f"on step {self.upload_count}: {error_msg}"
)
self._del_tmp_folder()
self._exception_list.clear()
self._to_be_del_files.clear()
self.async_task_peeding = False
if gpc.is_rank_for_log():
self.upload_count += 1
if self.async_mode:
self.save(
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
saved_obj=dict({"step": self.latest_save_step}),
async_upload=False,
)
storage_manager: StorageManager = None
def init_storage_manager(ckpt_config):
global storage_manager
storage_manager = StorageManager(
ckpt_config.enable_save_ckpt,
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
async_mode=ckpt_config.async_upload,
)
def get_storage_manager():
assert storage_manager is not None, "storage_manager has not been initialized!"
return storage_manager
def wait_async_upload_finish():
dist.barrier()
storage_manager.wait()

internlm/utils/writer.py

@ -0,0 +1,142 @@
import logging
import os
import socket
import sys
import traceback
from functools import partial
import torch
from torch.utils.tensorboard import SummaryWriter
from internlm.core.context import global_context as gpc
def tb_save_run_info(writer, config_lines, global_step=0):
writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
lines = []
for line in config_lines:
if line.strip().startswith("#"):
continue
lines.append(line)
writer.add_text(tag="config", text_string="\n".join(lines), global_step=global_step)
def init_tb_writer(
job_name: str,
launch_time: str,
file_name: str,
tensorboard_folder: str,
resume_tb_folder: str,
step_count: int,
config: str,
logger: logging.Logger,
):
tb_log_file_name = file_name
if not tensorboard_folder:
tb_folder = os.path.join(job_name, launch_time, "tensorboards")
else:
tb_folder = tensorboard_folder
if gpc.get_global_rank() == 0:
# If we don't load ckpt, 'resume_tb_folder' is set as the tensorboard
# dir of the last task by 'make_launch_script.sh'.
# If we load ckpt, 'resume_tb_folder' will be overwritten as the
# reloaded 'train_state.resume_tb_folder'.
if resume_tb_folder is not None:
assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
if not os.path.exists(resume_tb_folder):
logger.error(
f"Can't found resume_tb_folder{resume_tb_folder}, \
please make sure this folder is located at local file system."
)
else:
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}... ")
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
os.system(f"chmod -R +w {tb_folder}/")
else:
logger.info(f"Login tensorboard logs to: {tb_folder}")
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
writer.add_text(tag="job_name", text_string=job_name, global_step=step_count)
writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count)
torch.distributed.broadcast_object_list([tb_folder], src=0)
else:
objects = [None]
torch.distributed.broadcast_object_list(objects, src=0)
tb_folder = objects[0]
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
if gpc.is_rank_for_log():
tb_save_run_info(
writer=writer,
config_lines=config,
global_step=step_count,
)
writer.add_text(
tag=f"mapping_{tb_log_file_name}",
text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} device={torch.cuda.current_device()}",
global_step=step_count,
)
writer.add_scaler = partial(writer.add_scalar, new_style=True)
return writer, tb_logdir
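The rank-0 broadcast in `init_tb_writer` above is the standard `broadcast_object_list` pattern; a hedged, single-process sketch (gloo backend, world size 1, made-up folder name) shows the mechanics:

```python
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

if dist.get_rank() == 0:
    objects = ["my_job/20230801_120000/tensorboards"]  # rank 0 decides the folder
else:
    objects = [None]
dist.broadcast_object_list(objects, src=0)  # every rank now sees the same folder
tb_folder = objects[0]
print(tb_folder)
dist.destroy_process_group()
```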
class Writer:
"""
Customized writer based on tensorboard for recording training metrics.
Args:
job_name (str): The name of training job, defaults to None.
launch_time (str): A string representing the launch time of the training.
file_name (str): The log file name, defaults to None.
tensorboard_folder (str): A string representing the folder for saving tensorboard logs.
resume_tb_folder (str): A string representing the folder for resuming tensorboard logs.
step_count (int): An integer representing the step count of the training.
config (str): A string representing the configuration of the training.
logger (logging.Logger): A logging.Logger object for logging information during training.
enable_tb (bool): A boolean indicating whether to enable the tensorboard writer.
"""
def __init__(
self,
job_name: str = None,
launch_time: str = None,
file_name: str = None,
tensorboard_folder: str = None,
resume_tb_folder: str = None,
step_count: int = 0,
config: str = None,
logger: logging.Logger = None,
enable_tb: bool = True,
) -> None:
self.enable_tb = enable_tb
self.tb_writer, self.tb_logdir = init_tb_writer(
job_name=job_name,
launch_time=launch_time,
file_name=file_name,
tensorboard_folder=tensorboard_folder,
resume_tb_folder=resume_tb_folder,
step_count=step_count,
config=config,
logger=logger,
)
def add_scalar(self, key, value, step):
try:
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
except Exception:
traceback.print_exc()
def add_text(self, key, value, step):
try:
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_text(tag=key, text_string=value, global_step=step)
except Exception:
traceback.print_exc()
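A stand-alone sketch of the guarded logging pattern used by `Writer.add_scalar` above, without the distributed setup the real class needs: metrics only reach tensorboard when the writer is enabled, and logging failures are swallowed so training never crashes on them.

```python
import tempfile
import traceback

from torch.utils.tensorboard import SummaryWriter

enable_tb = True
tb_writer = SummaryWriter(log_dir=tempfile.mkdtemp(), max_queue=5, flush_secs=3)


def add_scalar(key, value, step):
    try:
        if enable_tb and tb_writer is not None:
            tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
    except Exception:
        traceback.print_exc()  # never let a logging error break the training loop


add_scalar("train/loss", 2.31, step=0)
tb_writer.close()
```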


@ -30,7 +30,6 @@ $ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_
```
The `bin` and `meta` files can be generated by running the following command:
```bash
$ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin
```


@ -14,7 +14,6 @@ This directory provide some tools for model training with the following file str
We need to use a `tokenizer` to generate `bin` and `meta` files for raw data. We import the tokenizer model by specifying the model weight path in `tools/tokenizer.py`. Currently, we provide `V7.model` to generate tokens. If you want to use a different model, you can modify the model weight path in `tokenizer.py` directly.
We can run the following command to generate `bin` and `meta` files corresponding to the original data. The parameter `text_input_path` represents the path of the original text data, currently supporting `txt`, `json`, and `jsonl` formats, while `bin_output_path` represents the save path of the generated `bin` files.
```bash
$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
```

train.py

@ -5,342 +5,76 @@ import socket
import time
import traceback
from functools import partial
from typing import Iterable
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler
from internlm.data.collaters import packed_collate_fn
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
PackedDataset,
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
from internlm.utils.common import (
BatchSkipper,
get_master_node,
get_megatron_flops,
get_process_rank,
launch_time,
parse_args,
)
from internlm.utils.logger import get_logger
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
load_context,
load_model_checkpoint,
load_optimizer_checkpoint,
load_sampler,
load_scheduler,
save_checkpoint,
)
from internlm.utils.parallel import (
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.simple_memory_profiler import (
SimpleMemoryProfiler,
build_activation_config,
)
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.writer import Writer
# global llm logger
logger = get_logger(__file__)
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
def initialize_llm_logger(start_time: str):
"""
Initialize distributed environment for distributed training.
Initialize customized uniscale logger.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
master_port (str): The master port for distributed training. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
start_time (str): The launch time of current training job.
Returns: The instance of uniscale logger.
"""
torch.cuda.empty_cache()
if launcher == "torch":
internlm.launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
internlm.launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
def initialize_model():
"""
Initialize model.
Returns: The neural network model to be trained or evaluated.
"""
assert (
not hasattr(gpc.config.parallel, "pipeline") or gpc.config.parallel.pipeline == 1
), "Pipeline parallelism is not supported for now."
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
uniscale_logger = initialize_uniscale_logger(
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
)
if uniscale_logger is not None:
global logger
logger = uniscale_logger
# This sync is very important because the model weights kept in the optimizer are copied
# from the original parameters in memory, so we must make sure the data-parallel sync
# does not make the optimizer's copy diverge from the original parameters.
sync_model_param(model, parallel_mode=ParallelMode.DATA)
# This function is needed to make sure parameters that are not split by tensor parallelism are
# the same across all tensor-parallel ranks.
sync_model_param_within_tp(model)
return model
def get_train_data_loader(num_worker: int = 0):
"""
Generate and return the training data loader.
Returns: A tuple of (train_dl, dataset_types).
"""
# Get the dataset types
dataset_types = None
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data
# Get the sample weight dictionary
train_folder = data_cfg.train_folder
if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = get_packed_dataset_without_short_length(
folder=data_cfg.train_folder,
packed_length=data_cfg.packed_length,
max_length_per_sample=data_cfg.seq_len,
show_progress=dist.get_rank() == 0,
min_length=data_cfg.min_length,
min_length_dict=data_cfg.get("min_length_dict", {}),
pack_into_one_sample=data_cfg.pack_sample_into_one,
)
# partition already completed
# assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen))
if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)):
datasets = [train_ds]
else:
datasets = train_ds.datasets
# Create the training dataset sampler
train_sampler = StaticBatchSampler(
datasets,
batch_size=data_cfg.micro_num,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=1024,
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
# Create the training data loader
train_dl = DataLoader(
dataset=train_ds,
batch_sampler=train_sampler,
num_workers=num_worker,
pin_memory=True,
collate_fn=train_collate_fn,
persistent_workers=True,
)
return train_dl, dataset_types
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
Args:
train_dl (torch.utils.data.DataLoader): Dataloader for training.
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
train_state (TrainState): Current training state.
Returns: A batch data and the updated train_iter.
"""
timer("batch-gen").start()
try:
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
next(train_state.batch_sampler_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
next(train_state.batch_sampler_iter)
train_state.num_consumed_samples_in_epoch = 0
timer("batch-gen").stop()
batch[0].pop("type_ids", None)
return batch, train_iter
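The epoch-rollover logic in `load_new_batch` above boils down to rebuilding the iterator on `StopIteration` and resetting the epoch-local counters; a stand-alone sketch with a plain list standing in for the DataLoader:

```python
train_dl = [1, 2, 3]  # stand-in for the real DataLoader
train_iter = iter(train_dl)
num_consumed_samples_in_epoch = 0

for _ in range(5):
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_dl)  # epoch finished: restart the iterator
        batch = next(train_iter)
        num_consumed_samples_in_epoch = 0
    num_consumed_samples_in_epoch += 1
    print(batch, num_consumed_samples_in_epoch)
```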
def initialize_optimizer(model: nn.Module):
"""
Initialize optimizer.
Args:
model (torch.nn.Module): Your model instance to be trained or evaluated.
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
adam_cfg = gpc.config.adam
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
lr=adam_cfg.lr,
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
eps=adam_cfg.adam_eps,
)
optimizer = HybridZeroOptimizer(
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
return optimizer, beta2_scheduler, lr_scheduler
def record_current_batch_training_metrics(
get_tflops_func,
logger,
success_update,
batch_count,
batch,
train_state,
optimizer,
beta2_scheduler,
trainer,
start_time,
loss,
grad_norm,
):
"""
Print some training metrics of current batch.
"""
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"]
if hasattr(trainer.engine.optimizer, "grad_scaler"):
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
num_tokens_in_batch = batch[1].nelement()
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
tk_per_gpu = 0
tk_per_gpu = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
tflops = get_tflops_func((time.time() - start_time))
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
}
infos["micro_num"] = len(batch[1])
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
infos["largest_length"] = max_length_in_batch # the longest input
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
infos["smallest_batch"] = min_samples_in_batch
infos["adam_beta2"] = beta2_scheduler.get_beta2()
line = ""
for k, v in infos.items():
line += f"{k}={v},"
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
line += f"fwd_bwd_time={fwd_bwd_time}"
logger.info(line)
return uniscale_logger
def main(args):
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# init setting
skip_batches = gpc.config.data.skip_batches
total_steps = gpc.config.data.total_steps
load_optimizer = gpc.config.ckpt.load_optimizer
valid_every = gpc.config.data.valid_every
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
# ckpt setting
save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
enable_save_ckpt = gpc.config.ckpt.enable_ckpt
checkpoint_every = gpc.config.ckpt.checkpoint_every
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
get_tflops_func = partial(
get_megatron_flops,
checkpoint=gpc.config.model.checkpoint,
@ -359,25 +93,8 @@ def main(args):
dist.broadcast_object_list(objs, src=0)
current_time = objs[0]
model_load_path = None
if load_resume_ckpt_folder is not None:
logger.info(
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_resume_ckpt_folder
elif load_model_only_folder is not None:
logger.info(
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_model_only_folder
else:
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},"
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
# initialize customized llm logger
uniscale_logger = initialize_llm_logger(start_time=current_time)
# initialize and resume train state
train_state = TrainState(gpc.config)
@ -385,32 +102,66 @@ def main(args):
# initialize model
model = initialize_model()
with open(args.config, "r") as f:
config_lines = f.readlines()
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
)
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
# initialize the train data loader
train_dl, _ = get_train_data_loader(num_worker=4)
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=4)
val_dls = get_validation_data_loader()
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized.
if model_load_path is not None:
load_model_checkpoint(folder=model_load_path, model=model)
ckpt_manager.try_load_model(current_time)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# Loading other persistent training states.
if load_resume_ckpt_folder is not None:
# load lr scheduler states.
load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
# load training states.
load_context(load_resume_ckpt_folder, train_dl, train_state)
# load dataloader sampler states.
load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler)
# load optimzier states.
if load_optimizer:
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
# initialize customized llm writer
writer = Writer(
job_name=gpc.config.JOB_NAME,
launch_time=current_time,
file_name=get_parallel_log_file_name(),
tensorboard_folder=gpc.config.tensorboard_folder,
resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt.
step_count=train_state.step_count, # resume from ckpt.
config=config_lines,
logger=logger,
enable_tb=gpc.config.enable_tb,
)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
dataset_types=dataset_types,
)
# initialize trainer
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
]
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
optimizer=optimizer,
@ -418,17 +169,17 @@ def main(args):
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks,
)
# initialize simple memory profiler
if args.profiling:
memory_profiler = SimpleMemoryProfiler(
model.model,
model,
optimizer.optim,
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
activation_config=build_activation_config(gpc.config.model.num_layers),
)
else:
memory_profiler = None
@ -441,89 +192,118 @@ def main(args):
# transfer the train data loader into train data iterator
train_iter = iter(train_dl)
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
start_time = time.time()
timer("one-batch").start()
start_time = time.time()
timer("one-batch").start()
# load batch data
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# load batch data
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# record the consumed samples in training
train_state.batch_count = batch_count
train_state.num_consumed_samples_in_epoch += len(batch[1])
if batch_skipper(batch_count): # skip this batch
if gpc.is_rank_for_log():
logger.info(f"Skip batch count:`{batch_count}`...")
timer("one-batch").stop()
continue
# record the consumed samples in training
train_state.batch_count = batch_count
train_state.num_consumed_samples_in_epoch += len(batch[1])
if batch_skipper(batch_count): # skip this batch
if gpc.is_rank_for_log():
logger.info(f"Skip batch count:`{batch_count}`...")
timer("one-batch").stop()
continue
# zero the grads of parameters
trainer.zero_grad()
# zero the grads of parameters
trainer.zero_grad()
# process data
if batch[0].get("type_ids", None) is not None:
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
timer("fwd-bwd").stop()
assert loss is not None
# do forward and backward
timer("fwd-bwd").start()
# update parameters, and returns (success_update, grad_norm)
trainer_result = trainer.step()
assert trainer_result is not None
_, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
timer("fwd-bwd").stop()
success_update, grad_norm = trainer_result
if success_update: # update parameters successfully
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1  # record the number of batches where the parameter update was skipped.
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
# update parameters, and returns (success_update, grad_norm)
trainer_result = trainer.step()
assert trainer_result is not None
# calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics(
get_tflops_func=get_tflops_func,
logger=logger,
success_update=success_update,
batch_count=batch_count,
batch=batch,
train_state=train_state,
optimizer=optimizer,
beta2_scheduler=beta2_scheduler,
trainer=trainer,
start_time=start_time,
loss=loss,
grad_norm=grad_norm,
)
success_update, grad_norm_groups = trainer_result
if success_update: # update parameters successfully
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1  # record the number of batches where the parameter update was skipped.
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address,
message=f"Warning: skip parameter update at step {batch_count}.",
)
timer("one-batch").stop()
if memory_profiler is not None:
memory_profiler.step()
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# save batch sampler that tracks the true consumed samples
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
save_checkpoint(
folder=save_ckpt_folder,
model=model,
optimizer=optimizer,
scheduler=lr_scheduler,
# calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics(
get_tflops_func=get_tflops_func,
logger=logger,
writer=writer,
success_update=success_update,
batch_count=batch_count,
batch=batch,
train_state=train_state,
model_config=gpc.config.model,
optimizer=optimizer,
beta2_scheduler=beta2_scheduler,
trainer=trainer,
start_time=start_time,
loss=loss,
grad_norm=np.array(grad_norm_groups),
metric=metric,
update_panel=uniscale_logger is not None,
)
# wait for all checkpoint uploads to be completed
dist.barrier()
timer("one-batch").stop()
# evaluate on validation data loaders
if valid_every > 0 and train_state.step_count % valid_every == 0:
evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
writer=writer,
logger=logger,
step_count=train_state.step_count,
update_panel=uniscale_logger is not None,
)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# save batch sampler that tracks the true consumed samples
now_break = ckpt_manager.try_save_checkpoint(train_state)
if now_break:
break
if memory_profiler is not None:
memory_profiler.step()
if batch_count % 2 == 0:
prof.step()
ckpt_manager.wait_async_upload_finish()
if __name__ == "__main__":
args = parse_args()
hostname = socket.gethostname()
try:
main(args)
except Exception:
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
traceback.print_exc()
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
try:
main(args)
except Exception:
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())


@ -1,23 +1,20 @@
"""
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers. We mainly modified part of the code logic to adapt to the generation of our model.
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
We mainly modified part of the code logic to adapt to the generation of our model.
Please refer to these links below for more information:
1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
3. transformers: https://github.com/huggingface/transformers
"""
from dataclasses import asdict
import streamlit as st
import torch
from dataclasses import dataclass, asdict
from typing import List, Optional, Callable, Optional
import copy
import warnings
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from tools.transformers.interface import generate_interactive, GenerationConfig
from tools.transformers.interface import GenerationConfig, generate_interactive
logger = logging.get_logger(__name__)
@ -25,9 +22,14 @@ logger = logging.get_logger(__name__)
def on_btn_click():
del st.session_state.messages
@st.cache_resource
def load_model():
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
model = (
AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
.to(torch.bfloat16)
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
return model, tokenizer
@ -35,20 +37,12 @@ def load_model():
def prepare_generation_config():
with st.sidebar:
max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
top_p = st.slider(
'Top P', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.slider(
'Temperature', 0.0, 1.0, 0.7, step=0.01
)
top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
st.button("Clear Chat History", on_click=on_btn_click)
generation_config = GenerationConfig(
max_length=max_length,
top_p=top_p,
temperature=temperature
)
generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
return generation_config
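The `GenerationConfig` built above is later splatted into `generate_interactive` via `asdict(...)`; a hedged illustration with a minimal stand-in dataclass (not the real `tools.transformers.interface.GenerationConfig`) shows how the sliders end up as keyword arguments:

```python
from dataclasses import asdict, dataclass


@dataclass
class GenerationConfig:  # minimal stand-in, not the real interface class
    max_length: int = 2048
    top_p: float = 0.8
    temperature: float = 0.7


def fake_generate(prompt, **kwargs):
    print(prompt, kwargs)


fake_generate("hello", **asdict(GenerationConfig(top_p=0.9)))
# -> hello {'max_length': 2048, 'top_p': 0.9, 'temperature': 0.7}
```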
@ -74,16 +68,16 @@ def combine_history(prompt):
def main():
#torch.cuda.empty_cache()
# torch.cuda.empty_cache()
print("load model begin.")
model, tokenizer = load_model()
print("load model end.")
user_avator = "doc/imgs/user.png"
robot_avator = "doc/imgs/robot.png"
st.title("InternLM-Chat-7B")
generation_config = prepare_generation_config()
# Initialize chat history
@ -106,22 +100,20 @@ def main():
with st.chat_message("robot", avatar=robot_avator):
message_placeholder = st.empty()
for cur_response in generate_interactive(model=model, tokenizer=tokenizer, prompt=real_prompt, additional_eos_token_id=103028, **asdict(generation_config)):
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=103028,
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(cur_response + "")
message_placeholder.markdown(cur_response)
# Add robot response to chat history
st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
torch.cuda.empty_cache()
if __name__ == "__main__":
main()