Merge develop to main (#314)

* feat: add unitest for model (#300)

* feat: add unitest for model

* feat:add model test

* Merge main to develop (#309)

* fix(chat): fix stream_chat to return generator (#123)

* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>

* docs(doc/code-docs): add figure for training docs (#307)

* add training image for docs

* docs(doc/code-docs): add training img for en doc

* docs(doc/code-docs): fix en docs for initialize

* docs(doc/code-docs): update conf file for readthedocs

* docs(doc/code-docs): fix typos

* docs(doc/code-docs): fix typos for reathedocs

* docs(doc/code-docs): minor typo fix for reathedocs

* docs(doc/code-docs): fix readthedocs conf file

* docs(doc/code-docs): update training image

* docs(doc/code-docs): fix typos

* docs(doc/code-docs): update training image

* docs(doc/code-docs): move training image to section initialize

* docs(doc/code-docs): fix lint

* add badge about reathedocs status

* Merge main to develop (#312)

* fix(chat): fix stream_chat to return generator (#123)

* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

* docs(doc/code-docs): update quickstart usage (#301)

* docs(usage.md): update usage.md

* docs(doc/code-docs): update en usage

---------

Co-authored-by: huangting4201 <huangting3@sensetime.com>

* docs(doc/code-docs): update en usage

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
Co-authored-by: huangting4201 <huangting3@sensetime.com>

* feat: more tgs (#310)

* feat:more tgs

* feat:add more tgs

* feat:more tgs

* feat: add optimizer_unitest (#303)

* feat: add optimizer_unitest

* feat: add optimizer test

* feat: add optimizer test

* feat:add optimizer test

* fianl change

* feat:add optimizer test

* feat:add optimizer test

* feat:add optimizer test

---------

Co-authored-by: jiaxingli <43110891+li126com@users.noreply.github.com>
Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
Co-authored-by: Season <caizheng@pjlab.org.cn>
Co-authored-by: huangting4201 <huangting3@sensetime.com>
pull/322/head v0.2.1dev20230915
huangting4201 2023-09-15 19:12:38 +08:00 committed by GitHub
parent 42802a2b31
commit 2710fa7343
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1150 additions and 79 deletions

View File

@ -16,6 +16,7 @@
[![license](./doc/imgs/license.svg)](./LICENSE) [![license](./doc/imgs/license.svg)](./LICENSE)
[![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/) [![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/)
[![Documentation Status](https://readthedocs.org/projects/internlm/badge/?version=latest)](https://internlm.readthedocs.io/zh_CN/latest/?badge=latest)
[📘使用法](./doc/en/usage.md) | [📘使用法](./doc/en/usage.md) |
[🛠️インストール](./doc/en/install.md) | [🛠️インストール](./doc/en/install.md) |

View File

@ -16,6 +16,7 @@
[![license](./doc/imgs/license.svg)](https://github.com/open-mmlab/mmdetection/blob/main/LICENSE) [![license](./doc/imgs/license.svg)](https://github.com/open-mmlab/mmdetection/blob/main/LICENSE)
[![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/) [![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/)
[![Documentation Status](https://readthedocs.org/projects/internlm/badge/?version=latest)](https://internlm.readthedocs.io/zh_CN/latest/?badge=latest)
[📘使用文档](./doc/usage.md) | [📘使用文档](./doc/usage.md) |
[🛠️安装教程](./doc/install.md) | [🛠️安装教程](./doc/install.md) |

View File

@ -16,6 +16,7 @@
[![license](./doc/imgs/license.svg)](./LICENSE) [![license](./doc/imgs/license.svg)](./LICENSE)
[![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/) [![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/)
[![Documentation Status](https://readthedocs.org/projects/internlm/badge/?version=latest)](https://internlm.readthedocs.io/zh_CN/latest/?badge=latest)
[📘Usage](./doc/en/usage.md) | [📘Usage](./doc/en/usage.md) |
[🛠Installation](./doc/en/install.md) | [🛠Installation](./doc/en/install.md) |

View File

@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: InternLM \n" "Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-11 14:25+0800\n" "POT-Creation-Date: 2023-09-13 17:07+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n" "Language: en\n"
@ -26,16 +26,19 @@ msgstr "Model Checkpointing"
#: ../../source/checkpoint.rst:4 #: ../../source/checkpoint.rst:4
msgid "" msgid ""
"InternLM 使用 ``internlm.utils.model_checkpoint.CheckpointManager`` " "InternLM 使用 ``internlm.utils.model_checkpoint.CheckpointManager`` "
"来管理模型保存。 其中,可以 使用 ``CheckpointManager.try_save_checkpoint(train_state)`` " "来管理模型保存。其中,可以使用 ``CheckpointManager.try_save_checkpoint(train_state)`` "
"来保存指定 step 的模型状态。InternLM支持启动时自动加载最新的模型备份并在接收信号退出训练时自动进行模型备份。" "来保存指定 step 的模型状态。"
msgstr "" msgstr ""
"InternLM uses ``internlm.utils.model_checkpoint.CheckpointManager`` to " "InternLM uses ``internlm.utils.model_checkpoint.CheckpointManager`` to "
"manage model checkpointing. In the implementation, we use " "manage model checkpointing. In the implementation, we use "
"``CheckpointManager.try_save_checkpoint(train_state)`` to checkpoint " "``CheckpointManager.try_save_checkpoint(train_state)`` to checkpoint "
"training states at specific steps. InternLM supports automatic loading of" "training states at specific steps. "
" latest ckpt at startup and automatic model checkpointing at signal quit."
#: ../../source/checkpoint.rst:8 #: ../../source/checkpoint.rst:6
msgid "InternLM支持启动时自动加载最新的模型备份并在接收信号退出训练时自动进行模型备份。"
msgstr "InternLM supports automatic loading of latest ckpt at startup and automatic model checkpointing at signal quit. "
#: ../../source/checkpoint.rst:9
msgid "Checkpointing" msgid "Checkpointing"
msgstr "" msgstr ""

View File

@ -37,8 +37,8 @@ msgstr "Start Training"
#: ../../source/example/30B_demo.rst:166 24974384d5ab42e68266aeb67ae222ce #: ../../source/example/30B_demo.rst:166 24974384d5ab42e68266aeb67ae222ce
msgid "完成以上训练配置后,可启动模型训练,以在 ``slurm`` 平台上为例,启动两节点 16GPU 的训练命令如下所示:" msgid "完成以上训练配置后,可启动模型训练,以在 ``slurm`` 平台上为例,启动两节点 16GPU 的训练命令如下所示:"
msgstr "After completing the data preparation and relevant training configurations, you can start the demo training. msgstr "After completing the data preparation and relevant training configurations, you can start the demo training. "
The following example shows how to start distributed training in ``slurm`` environments with 16 GPUs." "The following example shows how to start distributed training in ``slurm`` environments with 16 GPUs."
#: ../../source/example/30B_demo.rst:173 948ac71ed53848f9bad07f69d956c4bb #: ../../source/example/30B_demo.rst:173 948ac71ed53848f9bad07f69d956c4bb
msgid "训练结果" msgid "训练结果"

View File

@ -37,8 +37,8 @@ msgstr "Start Training"
#: ../../source/example/7B_demo.rst:164 9e7a864ae2e14d05b0681f16792e5278 #: ../../source/example/7B_demo.rst:164 9e7a864ae2e14d05b0681f16792e5278
msgid "完成以上训练配置后,可启动模型训练,以在 ``slurm`` 平台上为例,启动单节点 8GPU 的训练命令如下所示:" msgid "完成以上训练配置后,可启动模型训练,以在 ``slurm`` 平台上为例,启动单节点 8GPU 的训练命令如下所示:"
msgstr "After completing the data preparation and relevant training configurations, you can start the demo training. msgstr "After completing the data preparation and relevant training configurations, you can start the demo training. "
The following example shows how to start distributed training in ``slurm`` environments with 8 GPUs." "The following example shows how to start distributed training in ``slurm`` environments with 8 GPUs."
#: ../../source/example/7B_demo.rst:171 fdd053efb1854d46aabf6c0f279fe7fc #: ../../source/example/7B_demo.rst:171 fdd053efb1854d46aabf6c0f279fe7fc
msgid "训练结果" msgid "训练结果"

View File

@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: InternLM \n" "Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-08 15:32+0800\n" "POT-Creation-Date: 2023-09-14 12:23+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n" "Language: zh_CN\n"
@ -23,24 +23,68 @@ msgstr ""
msgid "训练构建" msgid "训练构建"
msgstr "Training Setup" msgstr "Training Setup"
#: ../../source/initialize.rst:7 #: ../../source/initialize.rst:4
msgid "InternLM 的训练流程可以归纳为两个步骤:"
msgstr "The training process of InternLM can be summarized into two steps: "
#: ../../source/initialize.rst:6
msgid "初始化"
msgstr "Initialization"
#: ../../source/initialize.rst:8
msgid "初始化模型、优化器、数据加载器、Trainer生成不同种类的进程组为混合并行的迭代训练做准备。"
msgstr ""
"Initialize model, optimizer, dataloader, trainer, and create different "
"types of process groups to prepare for iterative steps of hybrid parallel training. "
#: ../../source/initialize.rst:9
msgid "初始化Logger、Checkpoint管理器、Monitor管理器、Profiler对迭代训练的过程观察、预警、记录。"
msgstr ""
"Initialize logger, checkpoint manager, monitor manager, and profiler to "
"watch, alert, and record the iterative training steps. "
#: ../../source/initialize.rst:11
msgid "迭代训练"
msgstr "Iterative training steps"
#: ../../source/initialize.rst:13
msgid "根据配置文件定义的张量并行、流水线并行、数据并行的大小,加载训练引擎和调度器进行混合并行训练。"
msgstr ""
"Load the training engine and scheduler for hybrid parallel training "
"according to the configuration such as tensor parallel size, pipeline "
"parallel size, and data parallel size. "
#: ../../source/initialize.rst:14
msgid "在迭代训练中,调用 Trainer API 进行梯度置零,前向传播计算损失并反向传播,参数更新。"
msgstr ""
"In iterative training steps, the Trainer API is called to perform zero "
"gradients, forward-loss-backward, and parameter update."
#: ../../source/initialize.rst:20
msgid "InternLM训练流程图"
msgstr "InternLM training process"
#: ../../source/initialize.rst:25
msgid "命令行参数解析" msgid "命令行参数解析"
msgstr "Argument Parsing" msgstr "Argument Parsing"
#: ../../source/initialize.rst:9 #: ../../source/initialize.rst:27
#, fuzzy
msgid "" msgid ""
"InternLM 使用 `argparse <https://docs.python.org/3/library/argparse.html>`_" "InternLM 使用 `argparse <https://docs.python.org/3/library/argparse.html>`_"
" 库来向InternLM运行时提供命令行参数配置。用户可使用 " " 库来向InternLM运行时提供命令行参数配置。"
"``internlm.initialize.get_default_parser()`` 来获取 InternLM "
"的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。"
msgstr "" msgstr ""
"InternLM uses the `argparse " "InternLM uses the `argparse "
"<https://docs.python.org/3/library/argparse.html>`_ library to supply " "<https://docs.python.org/3/library/argparse.html>`_ library to supply "
"commandline configuration to the InternLM runtime. Use " "commandline configuration to the InternLM runtime. "
"``internlm.initialize.get_default_parser()`` to get InternLM's default "
"parser with some builtin arguments, users can add custom parameters to " #: ../../source/initialize.rst:29
"this parser." msgid ""
"用户可使用 ``internlm.initialize.get_default_parser()`` 来获取 InternLM "
"的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。"
msgstr ""
"Use ``internlm.initialize.get_default_parser()`` to get InternLM's "
"default parser with some builtin arguments, users can add custom "
"parameters to this parser."
#: internlm.initialize.launch.get_default_parser:1 of #: internlm.initialize.launch.get_default_parser:1 of
msgid "" msgid ""
@ -69,7 +113,7 @@ msgstr ""
msgid "返回类型" msgid "返回类型"
msgstr "" msgstr ""
#: ../../source/initialize.rst:25 #: ../../source/initialize.rst:45
msgid "模型初始化" msgid "模型初始化"
msgstr "Model Initialization" msgstr "Model Initialization"
@ -81,26 +125,26 @@ msgstr ""
msgid "The neural network model to be trained or evaluated." msgid "The neural network model to be trained or evaluated."
msgstr "" msgstr ""
#: ../../source/initialize.rst:29 #: ../../source/initialize.rst:49
msgid "InternLM 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下:" msgid "InternLM 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下:"
msgstr "" msgstr ""
"InternLM uses the field ``model_type`` and ``model`` in the config file " "InternLM uses the field ``model_type`` and ``model`` in the config file "
"to control model initialization process. An example model initialization " "to control model initialization process. An example model initialization "
"configuratio" "configuratio"
#: ../../source/initialize.rst:57 #: ../../source/initialize.rst:77
msgid "字段 ``model_type`` 指明了要初始化的模型类型" msgid "字段 ``model_type`` 指明了要初始化的模型类型"
msgstr "" msgstr ""
"The field ``model_type`` specifics the model type has been registered and" "The field ``model_type`` specifics the model type has been registered and"
" to be initialized." " to be initialized."
#: ../../source/initialize.rst:58 #: ../../source/initialize.rst:78
msgid "字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置" msgid "字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置"
msgstr "" msgstr ""
"The parameters in field ``model`` specific the configuration settings " "The parameters in field ``model`` specific the configuration settings "
"during model initialization." "during model initialization."
#: ../../source/initialize.rst:60 #: ../../source/initialize.rst:80
msgid "" msgid ""
"值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` " "值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` "
"注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 " "注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 "
@ -112,7 +156,7 @@ msgstr ""
" instantiated object of class ``internlm.util.registry.Registry``, the " " instantiated object of class ``internlm.util.registry.Registry``, the "
"example is shown as follows." "example is shown as follows."
#: ../../source/initialize.rst:72 #: ../../source/initialize.rst:92
msgid "优化器初始化" msgid "优化器初始化"
msgstr "Optimizer Initialization" msgstr "Optimizer Initialization"
@ -134,7 +178,7 @@ msgstr ""
msgid "A tuple of (optimizer, beta2_scheduler, lr_scheduler)." msgid "A tuple of (optimizer, beta2_scheduler, lr_scheduler)."
msgstr "" msgstr ""
#: ../../source/initialize.rst:79 #: ../../source/initialize.rst:99
msgid "数据加载器初始化" msgid "数据加载器初始化"
msgstr "Dataloader Initialization" msgstr "Dataloader Initialization"
@ -162,7 +206,7 @@ msgstr ""
msgid "A tuple of (train_dl, dataset_types)." msgid "A tuple of (train_dl, dataset_types)."
msgstr "" msgstr ""
#: ../../source/initialize.rst:86 #: ../../source/initialize.rst:106
msgid "Trainer 初始化" msgid "Trainer 初始化"
msgstr "Trainer Initialization" msgstr "Trainer Initialization"

View File

@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: InternLM \n" "Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-08 15:32+0800\n" "POT-Creation-Date: 2023-09-14 11:05+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n" "Language: en\n"
@ -32,13 +32,13 @@ msgid ""
"InternLM 使用 ``internlm.train.initialize_llm_profile()`` " "InternLM 使用 ``internlm.train.initialize_llm_profile()`` "
"来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler " "来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler "
"<https://pytorch.org/docs/stable/profiler.html>`_ ,输出的性能分析 trace 文件可以使用 " "<https://pytorch.org/docs/stable/profiler.html>`_ ,输出的性能分析 trace 文件可以使用 "
"`tensorboard <https://www.tensorflow.org>`_ 进行可视化。" "`tensorboard <https://www.tensorflow.org/tensorboard?hl=en>`_ 进行可视化。"
msgstr "" msgstr ""
"InternLM uses ``internlm.train.initialize_llm_profile()`` to profile " "InternLM uses ``internlm.train.initialize_llm_profile()`` to profile "
"performance data, execution time duration and breakdown analysis of step " "performance data, execution time duration and breakdown analysis of step "
"time. The implementation is based on `torch.profiler " "time. The implementation is based on `torch.profiler "
"<https://pytorch.org/docs/stable/profiler.html>`_ and output tracing " "<https://pytorch.org/docs/stable/profiler.html>`_ and output tracing "
"files can be visualized with `tensorboard <https://www.tensorflow.org>`_." "files can be visualized with `tensorboard <https://www.tensorflow.org/tensorboard?hl=en>`_."
#: ../../source/profiler.rst:11 #: ../../source/profiler.rst:11
msgid "" msgid ""
@ -53,11 +53,15 @@ msgstr ""
#: ../../source/profiler.rst:13 #: ../../source/profiler.rst:13
msgid "实际运行生成的 ``Torch Profiler`` 目录结构如下:" msgid "实际运行生成的 ``Torch Profiler`` 目录结构如下:"
msgstr "The directory structure of ``Torch Profiler`` generated files is as follows:" msgstr ""
"The directory structure of ``Torch Profiler`` generated files is as "
"follows:"
#: ../../source/profiler.rst:22 #: ../../source/profiler.rst:22
msgid "其中, ``traces`` 可以通过 ``TensorBoard`` 可视化,运行命令" msgid "其中, ``traces`` 可以通过 ``TensorBoard`` 可视化,运行命令"
msgstr "Among them, ``traces`` can be visualized through ``TensorBoard`` and run with the command" msgstr ""
"Among them, ``traces`` can be visualized through ``TensorBoard`` and run "
"with the command"
#: ../../source/profiler.rst:29 #: ../../source/profiler.rst:29
msgid "" msgid ""
@ -66,7 +70,12 @@ msgid ""
"tensorboard " "tensorboard "
"<https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html" "<https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"
"#pytorch-profiler-with-tensorboard>`_" "#pytorch-profiler-with-tensorboard>`_"
msgstr "In the opened ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` page, you can see the timeline of profiled operators and GPU kernels. For more usage, please refer to `torch profiler with tensorboard <https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard>`_" msgstr ""
"In the opened ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` page,"
" you can see the timeline of profiled operators and GPU kernels. For more"
" usage, please refer to `torch profiler with tensorboard "
"<https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"
"#pytorch-profiler-with-tensorboard>`_"
#: internlm.train.training_internlm.initialize_llm_profile:1 of #: internlm.train.training_internlm.initialize_llm_profile:1 of
msgid "Initialize and return the profiler context manager instance." msgid "Initialize and return the profiler context manager instance."

View File

@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: InternLM \n" "Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-07 10:56+0800\n" "POT-Creation-Date: 2023-09-14 12:23+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n" "Language: en\n"
@ -19,109 +19,144 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n" "Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n" "Generated-By: Babel 2.12.1\n"
#: ../../source/training.rst:2 6eafa5eb08e040039309a39cdb0f1bfe #: ../../source/training.rst:2
msgid "训练 API" msgid "训练 API"
msgstr "Training API" msgstr "Training API"
#: ../../source/training.rst:4 74d81f3d0ca54c839d4e80bd589aedb2 #: ../../source/training.rst:4
msgid "" msgid ""
"InternLM 的训练 API 由 ``internlm.core.trainer.Trainer`` " "InternLM 的训练 API 由 ``internlm.core.trainer.Trainer`` "
"管理。在定义了训练引擎和调度器之后,我们可以调用 Trainer API 来执行模型训练、评估、梯度清零和参数更新等。" "管理。在定义了训练引擎和调度器之后,我们可以调用 Trainer API 来执行模型训练、评估、梯度清零和参数更新等。"
msgstr "" msgstr ""
"InternLM training API is managed in ``internlm.core.trainer.Trainer``. After defining the " "InternLM training API is managed in ``internlm.core.trainer.Trainer``. "
"training engine and runtime scheduler, we can call training API to perform training, evaluation, " "After defining the training engine and runtime scheduler, we can call "
"zero gradients and parameter update steps." "training API to perform training, evaluation, zero gradients and "
"parameter update steps."
#: ../../source/training.rst:6 0e0cfddbb2334d3da99d3289edf4161d #: ../../source/training.rst:6
msgid "有关详细用法,请参阅 Trainer API 文档和示例。" msgid "有关详细用法,请参阅 Trainer API 文档和示例。"
msgstr "For detailed usage, please refer to Trainer API documentation and examples." msgstr ""
"For detailed usage, please refer to Trainer API documentation and "
"examples."
#: 7ea10280a8f1489984cb9994aa08976b internlm.core.trainer.Trainer:1 of #: internlm.core.trainer.Trainer:1 of
msgid "" msgid ""
"This is a class tending for easy deployments of users' training and " "This is a class tending for easy deployments of users' training and "
"evaluation instead of writing their own scripts." "evaluation instead of writing their own scripts."
msgstr "" msgstr ""
#: 7969dca55840451193bffd3b071ab3b3 aff576168b59460491bb5da0ce41ea74
#: internlm.core.trainer.Trainer internlm.core.trainer.Trainer.execute_schedule #: internlm.core.trainer.Trainer internlm.core.trainer.Trainer.execute_schedule
#: of #: of
msgid "参数" msgid "参数"
msgstr "" msgstr ""
#: 59754d3e9ee8452a872bf397c01e0d8c internlm.core.trainer.Trainer:4 of #: internlm.core.trainer.Trainer:4 of
msgid "Engine responsible for the process function." msgid "Engine responsible for the process function."
msgstr "" msgstr ""
#: 2d18ff15256e48f98901c7a7e0cbbe35 internlm.core.trainer.Trainer:6 of #: internlm.core.trainer.Trainer:6 of
msgid "Runtime schedule. Defaults to None." msgid "Runtime schedule. Defaults to None."
msgstr "" msgstr ""
#: 76f4b3c7feba40eca3ee2b32559c53f5 internlm.core.trainer.Trainer.engine:1 of #: internlm.core.trainer.Trainer.engine:1 of
msgid "" msgid ""
"Returns the engine that responsible for managing the training and " "Returns the engine that responsible for managing the training and "
"evaluation process." "evaluation process."
msgstr "" msgstr ""
#: c7eae2d4d06c4ef891e314902d80b7f3 internlm.core.trainer.Trainer.schedule:1 of #: internlm.core.trainer.Trainer.schedule:1 of
msgid "Returns the runtime scheduler." msgid "Returns the runtime scheduler."
msgstr "" msgstr ""
#: cb495b21b3444881aec83803e92386d9
#: internlm.core.trainer.Trainer.uses_pipeline:1 of #: internlm.core.trainer.Trainer.uses_pipeline:1 of
msgid "Returns whether the pipeline parallel is used or not." msgid "Returns whether the pipeline parallel is used or not."
msgstr "" msgstr ""
#: 86b0b631189e46468281a397c5e97350 internlm.core.trainer.Trainer.train:1 of #: internlm.core.trainer.Trainer.train:1 of
msgid "Sets the model to training mode." msgid "Sets the model to training mode."
msgstr "" msgstr ""
#: f997e13120ee4d8b9e45ea6698b3e2a6 internlm.core.trainer.Trainer.eval:1 of #: internlm.core.trainer.Trainer.eval:1 of
msgid "Sets the model to evaluation mode." msgid "Sets the model to evaluation mode."
msgstr "" msgstr ""
#: a8179e50312d47dcbe9de0433a65c2f7 internlm.core.trainer.Trainer.zero_grad:1 #: internlm.core.trainer.Trainer.zero_grad:1 of
#: of
msgid "Sets the gradient of all parameters in the model to zero." msgid "Sets the gradient of all parameters in the model to zero."
msgstr "" msgstr ""
#: f936136ef9e0452ca439b7c66dc8884b internlm.core.trainer.Trainer.step:1 of #: internlm.core.trainer.Trainer.step:1 of
msgid "Executes the parameter update step." msgid "Executes the parameter update step."
msgstr "" msgstr ""
#: 250e2af89cfd432c84d228f9e03c174c
#: internlm.core.trainer.Trainer.execute_schedule:1 of #: internlm.core.trainer.Trainer.execute_schedule:1 of
msgid "" msgid ""
"Runs the forward, loss computation, and backward for the model. Returns a" "Runs the forward, loss computation, and backward for the model. Returns a"
" tuple of (output, label, loss)." " tuple of (output, label, loss)."
msgstr "" msgstr ""
#: 6ca7de83033b432792eb0d7935ea04da
#: internlm.core.trainer.Trainer.execute_schedule:4 of #: internlm.core.trainer.Trainer.execute_schedule:4 of
msgid "The data iterator." msgid "The data iterator."
msgstr "" msgstr ""
#: 6d3044e75b3149beba3c659e15607b79
#: internlm.core.trainer.Trainer.execute_schedule:6 of #: internlm.core.trainer.Trainer.execute_schedule:6 of
msgid "Additional keyword arguments." msgid "Additional keyword arguments."
msgstr "" msgstr ""
#: 99d5a297d6414c30b432acf2566f0d3c
#: internlm.core.trainer.Trainer.execute_schedule of #: internlm.core.trainer.Trainer.execute_schedule of
msgid "返回" msgid "返回"
msgstr "" msgstr ""
#: b625ebf0cf874edba384456d33e740b4
#: internlm.core.trainer.Trainer.execute_schedule:8 of #: internlm.core.trainer.Trainer.execute_schedule:8 of
msgid "A tuple of (output, label, loss)." msgid "A tuple of (output, label, loss)."
msgstr "" msgstr ""
#: 391cde57d2e2478d8f83a7ad270c2a65
#: internlm.core.trainer.Trainer.execute_schedule of #: internlm.core.trainer.Trainer.execute_schedule of
msgid "返回类型" msgid "返回类型"
msgstr "" msgstr ""
#: d4c4fb0fbddb499786970509cf0c9e13
#: internlm.core.trainer.Trainer.execute_schedule:9 of #: internlm.core.trainer.Trainer.execute_schedule:9 of
msgid "Tuple[:class:`torch.Tensor`]" msgid "Tuple[:class:`torch.Tensor`]"
msgstr "" msgstr ""
#~ msgid "InternLM 的训练流程可以归纳为两个步骤:"
#~ msgstr "The training process of InternLM can be summarized into two steps: "
#~ msgid "初始化"
#~ msgstr "Initialization"
#~ msgid "初始化模型、优化器、数据加载器、Trainer生成不同种类的进程组为混合并行的迭代训练做准备。"
#~ msgstr ""
#~ "Initialize model, optimizer, dataloader, "
#~ "trainer, and create different types of"
#~ " process groups to prepare for "
#~ "iterative steps of hybrid parallel "
#~ "training. "
#~ msgid "初始化Logger、Checkpoint管理器、Monitor管理器、Profiler对迭代训练的过程观察、预警、记录。"
#~ msgstr ""
#~ "Initialize logger, checkpoint manager, monitor"
#~ " manager, and profiler to watch, "
#~ "alert, and record the iterative training"
#~ " steps. "
#~ msgid "迭代训练"
#~ msgstr "Iterative training steps"
#~ msgid "根据配置文件定义的张量并行、流水线并行、数据并行的大小,加载训练引擎和调度器进行混合并行训练。"
#~ msgstr ""
#~ "Load the training engine and scheduler"
#~ " for hybrid parallel training according "
#~ "to the configuration such as tensor "
#~ "parallel size, pipeline parallel size, "
#~ "and data parallel size. "
#~ msgid "在迭代训练中,调用 Trainer API 进行梯度置零,前向传播计算损失并反向传播,参数更新。"
#~ msgstr ""
#~ "In iterative training steps, the Trainer"
#~ " API is called to perform zero "
#~ "gradients, forward-loss-backward, and "
#~ "parameter update."
#~ msgid "InternLM训练流程图"
#~ msgstr "InternLM training process"

View File

@ -183,7 +183,8 @@ msgstr ""
#: ../../../usage.md:237 #: ../../../usage.md:237
msgid "接下来将详细介绍启动一个模型训练所需要进行的数据、模型、并行和监控等相关的配置。" msgid "接下来将详细介绍启动一个模型训练所需要进行的数据、模型、并行和监控等相关的配置。"
msgstr "let's discuss the data, model, parallel and monitoring configurations " msgstr ""
"let's discuss the data, model, parallel and monitoring configurations "
"required to start a model training." "required to start a model training."
#: ../../../usage.md:239 #: ../../../usage.md:239
@ -275,7 +276,6 @@ msgstr ""
"default value is -1" "default value is -1"
#: ../../../usage.md:325 #: ../../../usage.md:325
#, fuzzy
msgid "当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配" msgid "当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配"
msgstr "" msgstr ""
"When `zero1 <= 0`, the size of the zero1 process group is equal to the " "When `zero1 <= 0`, the size of the zero1 process group is equal to the "
@ -283,14 +283,12 @@ msgstr ""
"parameters will be split within the data parallel range." "parameters will be split within the data parallel range."
#: ../../../usage.md:326 #: ../../../usage.md:326
#, fuzzy
msgid "当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数" msgid "当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数"
msgstr "" msgstr ""
"When `zero1 == 1`, zero1 is not used, and all data parallel groups retain" "When `zero1 == 1`, zero1 is not used, and all data parallel groups retain"
" the complete optimizer state parameters." " the complete optimizer state parameters."
#: ../../../usage.md:327 #: ../../../usage.md:327
#, fuzzy
msgid "当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集" msgid "当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集"
msgstr "" msgstr ""
"When `zero1 > 1` and `zero1 <= data_parallel_world_size`, the zero1 " "When `zero1 > 1` and `zero1 <= data_parallel_world_size`, the zero1 "

View File

@ -1,8 +1,9 @@
模型保存 模型保存
=================== ===================
InternLM 使用 ``internlm.utils.model_checkpoint.CheckpointManager`` 来管理模型保存。 其中,可以 InternLM 使用 ``internlm.utils.model_checkpoint.CheckpointManager`` 来管理模型保存。其中,可以使用 ``CheckpointManager.try_save_checkpoint(train_state)`` 来保存指定 step 的模型状态。
使用 ``CheckpointManager.try_save_checkpoint(train_state)`` 来保存指定 step 的模型状态。InternLM支持启动时自动加载最新的模型备份并在接收信号退出训练时自动进行模型备份。
InternLM支持启动时自动加载最新的模型备份并在接收信号退出训练时自动进行模型备份。
Checkpointing Checkpointing
------------- -------------

View File

@ -72,14 +72,14 @@ exclude_patterns = []
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme" html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"] html_static_path = []
# GitHub integration # GitHub integration
html_context = { html_context = {
"display_github": True, "display_github": True,
"github_user": "InternLM", "github_user": "InternLM",
"github_repo": "InternLM", "github_repo": "InternLM",
"github_version": "master", "github_version": "main",
"conf_py_path": "/doc/code-docs/source/", "conf_py_path": "/doc/code-docs/source/",
} }

View File

@ -1,12 +1,32 @@
训练构建 训练构建
============== ==============
InternLM 的训练流程可以归纳为两个步骤:
1. 初始化
* 初始化模型、优化器、数据加载器、Trainer生成不同种类的进程组为混合并行的迭代训练做准备。
* 初始化Logger、Checkpoint管理器、Monitor管理器、Profiler对迭代训练的过程观察、预警、记录。
2. 迭代训练
* 根据配置文件定义的张量并行、流水线并行、数据并行的大小,加载训练引擎和调度器进行混合并行训练。
* 在迭代训练中,调用 Trainer API 进行梯度置零,前向传播计算损失并反向传播,参数更新。
.. figure:: ../../imgs/hybrid_parallel_training.png
:scale: 45%
:class: with-border
InternLM训练流程图
.. _InternLM-args: .. _InternLM-args:
命令行参数解析 命令行参数解析
---------------- ----------------
InternLM 使用 `argparse <https://docs.python.org/3/library/argparse.html>`_ 库来向InternLM运行时提供命令行参数配置。用户可使用 ``internlm.initialize.get_default_parser()`` 来获取 InternLM 的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。 InternLM 使用 `argparse <https://docs.python.org/3/library/argparse.html>`_ 库来向InternLM运行时提供命令行参数配置。
用户可使用 ``internlm.initialize.get_default_parser()`` 来获取 InternLM 的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。
.. code-block:: python .. code-block:: python

View File

@ -6,7 +6,7 @@
Torch Profiler Torch Profiler
----------------- -----------------
InternLM 使用 ``internlm.train.initialize_llm_profile()`` 来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler <https://pytorch.org/docs/stable/profiler.html>`_ ,输出的性能分析 trace 文件可以使用 `tensorboard <https://www.tensorflow.org>`_ 进行可视化。 InternLM 使用 ``internlm.train.initialize_llm_profile()`` 来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler <https://pytorch.org/docs/stable/profiler.html>`_ ,输出的性能分析 trace 文件可以使用 `tensorboard <https://www.tensorflow.org/tensorboard?hl=en>`_ 进行可视化。
用户如果想使用这个 torch 性能分析工具,需要在启动训练时传递 ``--profiling`` 参数以启用性能分析。完成 torch 性能分析后,用户可以在 ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` 文件夹中看到性能分析结果。 用户如果想使用这个 torch 性能分析工具,需要在启动训练时传递 ``--profiling`` 参数以启用性能分析。完成 torch 性能分析后,用户可以在 ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` 文件夹中看到性能分析结果。

View File

@ -1,2 +1,2 @@
问&答 问&答
==== =====

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

View File

@ -4,6 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import json import json
from collections import deque
from typing import Iterable, Optional from typing import Iterable, Optional
from internlm.core.engine import Engine from internlm.core.engine import Engine
@ -58,6 +59,24 @@ class TrainState:
if batch_sampler: if batch_sampler:
self.init_batch_sampler(batch_sampler) self.init_batch_sampler(batch_sampler)
# tgs statistic
self.tgs_statistic = {
"sum_step": 0,
"sum_tg": 0,
"sum_time": 0,
"sum_last_tg_10": 0,
"sum_last_time_10": 0,
"sum_last_tg_50": 0,
"sum_last_time_50": 0,
"SMA_tg_50": 0,
"SMA_time_50": 0,
"SMA_tg_50_list": deque(),
"SMA_time_50_list": deque(),
"sum_tgs": 0,
"last_tgs_10": 0,
"last_tgs_50": 0,
}
def init_batch_sampler(self, batch_sampler): def init_batch_sampler(self, batch_sampler):
""" """
Args: Args:

View File

@ -372,9 +372,52 @@ def record_current_batch_training_metrics(
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
time_cost = time.time() - start_time
tk_per_gpu = 0
tk_per_gpu = round( tk_per_gpu = round(
num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
4,
)
tgs_statistic = train_state.tgs_statistic
tgs_statistic["sum_step"] += 1
tgs_statistic["sum_tg"] += tk_per_gpu
tgs_statistic["sum_time"] += time_cost
tgs_statistic["sum_last_tg_10"] += tk_per_gpu
tgs_statistic["sum_last_time_10"] += time_cost
tgs_statistic["sum_last_tg_50"] += tk_per_gpu
tgs_statistic["sum_last_time_50"] += time_cost
tgs_statistic["SMA_tg_50"] += tk_per_gpu
tgs_statistic["SMA_time_50"] += time_cost
tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
tgs_statistic["SMA_time_50_list"].append(time_cost)
if tgs_statistic["sum_step"] > 50:
tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
tgs_statistic["SMA_tg_50_list"].popleft()
tgs_statistic["SMA_time_50_list"].popleft()
last_tgs_1 = round(tk_per_gpu / time_cost, 2)
tgs_statistic["sum_tgs"] += last_tgs_1
if tgs_statistic["sum_step"] % 10 == 0:
tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
tgs_statistic["sum_last_tg_10"] = 0
tgs_statistic["sum_last_time_10"] = 0
if tgs_statistic["sum_step"] % 50 == 0:
tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
tgs_statistic["sum_last_tg_50"] = 0
tgs_statistic["sum_last_time_50"] = 0
last_tgs_10 = tgs_statistic["last_tgs_10"]
last_tgs_50 = tgs_statistic["last_tgs_50"]
tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["sum_time"], 2)
tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
tflops = get_tflops_func((time.time() - start_time))
tgs_origin = round(
num_tokens_in_batch num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA) * gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL) / gpc.get_world_size(ParallelMode.GLOBAL)
@ -382,13 +425,17 @@ def record_current_batch_training_metrics(
2, 2,
) )
tflops = get_tflops_func((time.time() - start_time))
infos = { infos = {
"tflops": tflops, "tflops": tflops,
"step": batch_count, "step": batch_count,
"loss": loss.item(), "loss": loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu, "tgs (tokens/gpu/second)": tgs_origin,
"tgs/last_tgs_1": last_tgs_1,
"tgs/tgs_all": tgs_all,
"tgs/tgs_avg": tgs_avg,
"tgs/tgs_SMA": tgs_SMA,
"tgs/last_tgs_10": last_tgs_10,
"tgs/last_tgs_50": last_tgs_50,
"lr": lr, "lr": lr,
"loss_scale": scaler, "loss_scale": scaler,
"grad_norm": grad_norm, "grad_norm": grad_norm,
@ -428,7 +475,7 @@ def record_current_batch_training_metrics(
"num_consumed_tokens": train_state.num_consumed_tokens, "num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item(), "loss": loss.item(),
"flops": tflops, "flops": tflops,
"tgs": tk_per_gpu, "tgs": last_tgs_1,
"acc": acc_perplex["acc"], "acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"], "perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time, "fwd_bwd_time": fwd_bwd_time,

View File

@ -0,0 +1,65 @@
import multiprocessing as mp
import pytest
import torch
from internlm.model.embedding import Embedding1D
from tests.test_model.test_model_internlm import build_environment, seed_all
def check_embedding(args):
# init
rank, world_size = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
vocab_size = 4
hidden_size = 2
# fix seed
seed_all(1024)
# define embedding
embedding = Embedding1D(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
padding_idx=None,
)
embedding.weight.data.copy_(torch.randn(vocab_size, hidden_size))
embedding = embedding.to(device)
# create input
input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
result = embedding(input_ids)
standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
standard_result = torch.tensor(standard_list).to(device)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol, equal_nan=True)
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = embedding.weight.grad
standard_glist = [[-0.4461, 0.5602], [0.4353, 1.2988], [-0.0625, -1.3609], [0.9595, -0.1144]]
standard_grad = torch.tensor(standard_glist).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)
@pytest.mark.embedding
def test_embedding():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_embedding, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_embedding.py"])

View File

@ -0,0 +1,379 @@
import multiprocessing as mp
import random
import numpy as np
import pytest
import torch
from torch import nn
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.linear import RewardModelLinear, ScaleColumnParallelLinear
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.model.utils import gather_forward_split_backward
config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
checkpoint=False,
num_attention_heads=2,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
resume_tb_folder="",
tensorboard_folder="",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
)
)
def build_environment(rank, world_size):
import os
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
def check_block(args):
# init
rank, world_size = args
build_environment(rank, world_size)
device = torch.device("cuda")
rtol, atol = (1e-3, 5e-3)
# fix seed
seed_all(1024)
# define block
blocks = nn.ModuleList(
[
PackedFlashBaseLayer1D(
hidden_size=4, # 768
num_attention_heads=2, # 12
mlp_ratio=2,
attn_drop_rate=0.0,
drop_rate=0.0,
dtype=torch.bfloat16,
layer_norm_epsilon=1e-5,
checkpoint=lid < 0,
layer_idx=lid + 0, # This parameter is used for caching during generation
residual_in_fp32=False,
device=device,
norm_type="rmsnorm",
dropout_selective_checkpoint=True,
use_scaled_init=True,
use_swiglu=True,
)
for lid in range(4) # 32
]
)
# create input
cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device) # [0, 8, 16]
indexes = torch.tensor([0, 1, 0, 1]).to(device) # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device) # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
hidden_states = torch.tensor(
[
[
[-1.1620, 1.3113, 0.1507, 2.2698],
[-1.2610, 1.0990, 0.3787, -0.3478],
[1.4001, 1.1982, -0.6696, 0.3269],
[1.3304, 1.2262, 1.0735, -1.1169],
]
]
)
hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
# forward
for _, block in enumerate(blocks):
block = block.to(torch.bfloat16)
block = block.to(device)
hidden_states = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=None,
max_seqlen=max_seqlen,
)
result = hidden_states
standard_result = torch.tensor(
[
[-1.1621, 1.3111, 0.1509, 2.2697],
[-1.2611, 1.0988, 0.3787, -0.3478],
[1.4000, 1.1982, -0.6694, 0.3268],
[1.3303, 1.2262, 1.0736, -1.1169],
]
).to(device)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
standard_grad = torch.tensor(
[
[0.7999, -0.2595, 0.2649, -1.3256],
[0.7064, 0.0283, -0.5508, 0.6494],
[-1.4657, -2.0316, 1.3776, 0.7211],
[-0.6046, 0.4329, -0.1884, 1.1170],
]
).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
def check_head(args):
# init
rank, world_size, is_reward = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
hidden_size = 4
vocab_size = 4
embed_grad_scale = 1
# fix seed
seed_all(1024)
# load standard
if is_reward:
head_cls = RewardModelLinear
standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device)
standard_grad = torch.tensor(
[
[-0.2246, 0.0164, -0.0591, 0.1660],
[-0.5625, 0.0408, -0.1484, 0.4160],
[-0.1758, 0.0128, -0.0464, 0.1299],
[-0.4785, 0.0347, -0.1260, 0.3516],
],
dtype=torch.bfloat16,
).to(device)
else:
head_cls = ScaleColumnParallelLinear
standard_result = torch.tensor(
[
[3.5938, -2.2188, 2.0312, 3.5625],
[1.0703, -1.1797, 1.1406, 1.6641],
[3.6250, -2.0156, 1.7656, 3.4531],
[3.6250, -2.0156, 1.7656, 3.4531],
],
dtype=torch.bfloat16,
).to(device)
standard_grad = torch.tensor(
[
[-0.2354, 0.0981, -0.2930, -0.6328],
[0.2344, -0.2334, -0.0918, 0.1396],
[-0.5898, -1.0156, -0.7070, 1.3750],
[0.0242, -0.1494, 0.1206, -0.0427],
],
dtype=torch.bfloat16,
).to(device)
# define head
head = head_cls(
in_features=hidden_size,
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=torch.bfloat16,
weight_scale=embed_grad_scale,
)
head = head.to(torch.bfloat16)
head = head.to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
dtype=torch.bfloat16,
requires_grad=True,
).to(device)
# forward
result = head(hidden_states)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
def check_gather_forward(args):
# init
rank, world_size, parallel_tensor = args
assert parallel_tensor in [1, 2]
config.parallel.tensor = parallel_tensor
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
# fix seed
seed_all(1024)
# load standard
if parallel_tensor == 1:
standard_result = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
]
).to(device)
standard_grad = torch.tensor(
[
[-0.4461, 0.5602, -0.0625, -1.3609],
[0.4353, 1.2988, 0.9595, -0.1144],
[-0.7593, -0.4031, 0.2041, 1.4955],
[0.5706, 0.9047, -0.6965, -0.3757],
]
).to(device)
else:
standard_result = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000, 8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000, 3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
]
).to(device)
if rank % 2 == 0:
standard_grad = torch.tensor(
[
[-0.4461, 0.5602, -0.0625, -1.3609],
[-0.7593, -0.4031, 0.2041, 1.4955],
[0.8093, 1.7580, 1.2996, -0.7545],
[1.0474, -0.5767, -1.0401, 0.8233],
]
).to(device)
else:
standard_grad = torch.tensor(
[
[0.4353, 1.2988, 0.9595, -0.1144],
[0.5706, 0.9047, -0.6965, -0.3757],
[-1.3589, -0.7202, 0.6094, -0.8208],
[-1.0042, 0.3695, 0.2511, -0.2718],
]
).to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
requires_grad=True,
).to(device)
# forward
result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
loss = torch.randn_like(result)
hidden_states.retain_grad()
# backward
result.backward(loss)
grad = hidden_states.grad
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
@pytest.mark.block
def test_block():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_block, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
@pytest.mark.head
@pytest.mark.parametrize("is_reward", [True, False])
def test_head(is_reward):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_head, [[rank, 8, is_reward] for rank in range(8)])
pool.close()
pool.join()
@pytest.mark.gather_forward
@pytest.mark.parametrize("parallel_tensor", [1, 2])
def test_gather_forward(parallel_tensor):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_gather_forward, [[rank, 8, parallel_tensor] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_model_internlm.py"])

View File

@ -0,0 +1,84 @@
import multiprocessing as mp
import pytest
import torch
from internlm.model.utils import try_import_RMSNorm
from tests.test_model.test_model_internlm import build_environment, seed_all
RMSNorm = try_import_RMSNorm()
def check_norm(args):
# init
rank, world_size = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
hidden_size = 4
layer_norm_epsilon = 1e-05
# fix seed
seed_all(1024)
# define norm
norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
norm = norm.to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
requires_grad=True,
).to(device)
# forward
result = norm(hidden_states.float())
standard = torch.tensor(
[
[1.6329, 0.3753, 1.0746, 0.1950],
[1.4288, 1.2626, 0.4268, 0.4268],
[1.6490, 0.2536, 1.0852, 0.1970],
[1.6490, 0.2536, 1.0852, 0.1970],
]
).to(device)
# check output
assert torch.allclose(result, standard, rtol=rtol, atol=atol, equal_nan=True)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
standard_grad = torch.tensor(
[
[-0.0193, 0.1248, 0.0324, -0.2573],
[-0.2140, 0.2010, 0.2901, -0.1683],
[-0.0815, -0.0689, 0.0850, 0.3027],
[0.0847, 0.1739, -0.1554, -0.0773],
]
).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)
@pytest.mark.norm
def test_norm():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_norm, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_norm.py"])

View File

@ -0,0 +1,364 @@
import copy
import multiprocessing as mp
import random
import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import internlm
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
class MlpModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
dtype=torch.bfloat16,
),
resume_tb_folder="",
tensorboard_folder="",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
grad_scaler=dict(
fp16=dict(
initial_scale=1,
min_scale=1,
growth_interval=1,
),
growth_factor=1.1,
backoff_factor=0.9,
max_scale=1,
hysteresis=1,
),
adam=dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
),
hybrid_zero_optimizer=dict(
overlap_sync_grad=False,
overlap_sync_param=False,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
),
)
)
def build_environment(rank, world_size):
import os
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
if dtype is torch.float32:
rtol = 1.3e-6
atol = 1e-5
elif dtype is torch.bfloat16:
rtol = 2e-2
atol = 2e-2
if isinstance(a, torch.Tensor):
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
def init_optimizer_grouped_parameters(check_group, model):
if check_group:
optimizer_grouped_parameters = [
{
"params": list(model.parameters())[:2],
"weight_decay": config.adam.weight_decay,
},
{
"params": list(model.parameters())[2:],
"weight_decay": config.adam.weight_decay,
},
]
else:
optimizer_grouped_parameters = [{"params": model.parameters(), "weight_decay": config.adam.weight_decay}]
return optimizer_grouped_parameters
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
def exam_hybrid_zero_optim_with_ddp(args):
# init
rank, world_size, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype = args
# TODO: Need to test the combine of overlap param and group_params when ready
# ParamBcastSyncHandler does not consider paramters in different optimizer group currently
if overlap_sync_param and check_group:
return
config.parallel.zero1 = zero_parallel
config.hybrid_zero_optimizer.overlap_sync_param = overlap_sync_param
config.hybrid_zero_optimizer.overlap_sync_grad = overlap_sync_grad
config.data.micro_num = micro_num
config.model.dtype = dtype
totel_step = 5
if not overlap_sync_param:
totel_step = 1
build_environment(rank, world_size)
seed_all(1024)
# create models
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model).to(dtype)
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
# create optimizer
if config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
else:
param_bcast_sync_handler = None
optimizer_grouped_parameters_zero = init_optimizer_grouped_parameters(check_group, zero_model)
optimizer_grouped_parameters_torch = init_optimizer_grouped_parameters(check_group, torch_model)
naive_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters_zero,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
torch_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters_torch,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
for _ in range(totel_step):
zero_optimizer.zero_grad()
torch_optimizer.zero_grad()
zero_optimizer.skip_grad_reduce = True
for num in range(micro_num):
if num == micro_num - 1:
zero_optimizer.skip_grad_reduce = False
seed_all(1024 + rank)
# create input
input_data = torch.rand(16, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.to(dtype))
# torch-ddp forward
torch_output = torch_model(input_data)
# check output
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean())
# torch-ddp backward
if num == micro_num - 1:
torch_output.mean().backward()
else:
with torch_model.no_sync():
torch_output.mean().backward()
# zero-dp step
zero_optimizer.step()
# torch-ddp step
torch_optimizer.step()
# check grad
if check_group:
group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
for torch_parm, zero_parm in group1:
if zero_parm.grad is not None:
loose_close(torch_parm.grad, zero_parm.grad, dtype=dtype)
for torch_parm, zero_parm in group2:
if zero_parm.grad is not None:
loose_close(torch_parm.grad, zero_parm.grad, dtype=dtype)
else:
for torch_parm, zero_parm in zip(torch_model.parameters(), zero_model.parameters()):
if zero_parm.grad is not None:
loose_close(torch_parm.grad, zero_parm.grad, dtype=dtype)
torch.cuda.synchronize()
# check updated param
if check_group:
group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
for torch_parm, zero_parm in group1:
loose_close(torch_parm, zero_parm, dtype=dtype)
for torch_parm, zero_parm in group2:
loose_close(torch_parm, zero_parm, dtype=dtype)
else:
for torch_parm, zero_parm in zip(torch_model.parameters(), zero_model.parameters()):
loose_close(torch_parm, zero_parm, dtype=dtype)
def exam_hybrid_zero_optim_with_ckpt_load_save(args):
# init
rank, world_size, zero_parallel, check_group, dtype = args
config.parallel.zero1 = zero_parallel
config.parallel.dtype = dtype
build_environment(rank, world_size)
# create models
zero_model = MlpModel().cuda().to(dtype)
# create optimizer
if config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
else:
param_bcast_sync_handler = None
optimizer_grouped_parameters1 = init_optimizer_grouped_parameters(check_group, zero_model)
optimizer_grouped_parameters2 = init_optimizer_grouped_parameters(check_group, zero_model)
naive_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters1,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
naive_optimizer2 = torch.optim.AdamW(
params=optimizer_grouped_parameters2,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer2 = HybridZeroOptimizer(
naive_optimizer2,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
# save and load states
states = zero_optimizer.state_dict()
zero_optimizer2.load_state_dict(states)
# check fp32 model weights
for zero1_param, zero2_param in zip(
zero_optimizer._fp32_flat_param_groups_of_current_rank.values(),
zero_optimizer2._fp32_flat_param_groups_of_current_rank.values(),
):
assert torch.equal(zero1_param, zero2_param)
# check fp16 model weights
for zero1_param, zero2_param in zip(
zero_optimizer._fp16_param_groups.values(), zero_optimizer2._fp16_param_groups.values()
):
assert zero1_param == zero2_param
zero_parallel_check_list = [-1, 1, 4]
overlap_sync_param_check_list = [True, False]
overlap_sync_grad_check_list = [True, False]
miro_num_check_list = [1, 2, 4]
check_group_list = [True, False]
dtype_list = [torch.float32, torch.bfloat16]
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("overlap_sync_param", overlap_sync_param_check_list)
@pytest.mark.parametrize("overlap_sync_grad", overlap_sync_grad_check_list)
@pytest.mark.parametrize("micro_num", miro_num_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ddp(
zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype
):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ddp,
[
[rank, 8, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype]
for rank in range(8)
],
)
pool.close()
pool.join()
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ckpt_load_save(zero_parallel, check_group, dtype):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ckpt_load_save,
[[rank, 8, zero_parallel, check_group, dtype] for rank in range(8)],
)
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_optimizer.py"])