Merge remote-tracking branch 'origin/develop'

pull/299/head v0.2.1dev20230908
yingtongxiong 2023-09-08 17:58:04 +08:00
commit 9481df976f
39 changed files with 1960 additions and 479 deletions

View File

@ -1,4 +1,5 @@
JOB_NAME = "7b_train"
DO_ALERT = False
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
@ -22,13 +23,16 @@ CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training (load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded; currently only the "internlm" type is supported.
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
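
For reference, a hedged sketch of the same `ckpt` block configured to resume a full training run instead of initializing from model weights only (field names follow the guide above; `SAVE_CKPT_FOLDER` and `CHECKPOINT_EVERY` are as defined earlier in this config, and the load path is illustrative):

```python
# Resume-from-checkpoint variant (sketch): load model, sampler, optimizer and scheduler
# states in one go via content=("all",).
LOAD_CKPT_FOLDER = "local:llm_ckpts/latest"  # illustrative path
ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder=SAVE_CKPT_FOLDER,
    load_ckpt_info=dict(path=LOAD_CKPT_FOLDER, content=("all",), ckpt_type="internlm"),
    checkpoint_every=CHECKPOINT_EVERY,
)
```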
@ -52,6 +56,8 @@ data = dict(
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
diag_outlier_ratio=1.1,
)
grad_scaler = dict(
@ -147,3 +153,12 @@ parallel = dict(
cudnn_deterministic = False
cudnn_benchmark = False
monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
),
)

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-07 14:15+0800\n"
"POT-Creation-Date: 2023-09-08 15:32+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
@ -19,26 +19,29 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../source/initialize.rst:2 b829330eebd24620b745072bbfc26c98
#: ../../source/initialize.rst:2
msgid "训练构建"
msgstr "Training Setup"
#: ../../source/initialize.rst:7 8c8472b4647a4de8998d75b9ec6f09eb
#: ../../source/initialize.rst:7
msgid "命令行参数解析"
msgstr "Argument Parsing"
#: ../../source/initialize.rst:8 f74176fa4aee4bbfaf989ffab9283ee7
#: ../../source/initialize.rst:9
#, fuzzy
msgid ""
"InternLM 使用 `argparse <https://docs.python.org/3/library/argparse.html>`_"
" 库来向InternLM运行时提供命令行参数配置。用户可 使用 "
" 库来向InternLM运行时提供命令行参数配置。用户可使用 "
"``internlm.initialize.get_default_parser()`` 来获取 InternLM "
"的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。"
msgstr ""
"InternLM uses the `argparse <https://docs.python.org/3/library/argparse.html>`_ library to supply commandline "
"configuration to the InternLM runtime. Use ``internlm.initialize.get_default_parser()`` to get InternLM's default "
"parser with some builtin arguments; users can add custom parameters to this parser."
"InternLM uses the `argparse "
"<https://docs.python.org/3/library/argparse.html>`_ library to supply "
"commandline configuration to the InternLM runtime. Use "
"``internlm.initialize.get_default_parser()`` to get InternLM's default "
"parser with some builtin arguments; users can add custom parameters to "
"this parser."
#: 9930855b85bf41ed8712fc40e1e034f7
#: internlm.initialize.launch.get_default_parser:1 of
msgid ""
"Reads user command line and uses an argument parser to parse the input "
@ -46,9 +49,6 @@ msgid ""
" local rank, backend for torch.distributed."
msgstr ""
#: 015003b013e346bea15b4514f2001a25 544472c2ce3c43bfb59317083c6b55c9
#: 7ee60ba1a92a4b9e8174049fb498a4f0 bca7c66f1a5a4517958bcea1e09d5d10
#: f5cbe452ae694c7884ac4596a7735bf6
#: internlm.initialize.initialize_trainer.initialize_trainer
#: internlm.initialize.launch.get_default_parser
#: internlm.train.training_internlm.get_train_data_loader
@ -57,55 +57,50 @@ msgstr ""
msgid "返回"
msgstr ""
#: 9b04c3d6b98b44ee89f800b71e8d80a9
#: internlm.initialize.launch.get_default_parser:4 of
msgid ""
"Returns the parser with the default arguments, the user may add "
"customized arguments into this parser."
msgstr ""
#: 147005b197e64c4b9a96a7cfe78045bc 3634f79c9aa547a48eb3fd7f150deb51
#: d3f0aa4143c84b719cd0b53170dd86c1
#: internlm.initialize.initialize_trainer.initialize_trainer
#: internlm.initialize.launch.get_default_parser
#: internlm.train.training_internlm.initialize_model of
msgid "返回类型"
msgstr ""
#: ../../source/initialize.rst:25 db2bf9d3ff81483dbf218e63dd4bbbe4
#: ../../source/initialize.rst:25
msgid "模型初始化"
msgstr "Model Initialization"
#: 5c2e33e254d4495fbc4b0226aac1fddb
#: internlm.train.training_internlm.initialize_model:1 of
msgid "Initialize model with Automatic Mixed Precision."
msgstr ""
#: c1254615508542b680daf73374844f9e
#: internlm.train.training_internlm.initialize_model:3 of
msgid "The neural network model to be trained or evaluated."
msgstr ""
#: ../../source/initialize.rst:29 b9867771b9da40cd8f3a55ee5ab95f65
#: ../../source/initialize.rst:29
msgid "InternLM 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下:"
msgstr ""
"InternLM uses the fields ``model_type`` and ``model`` in the config file "
"to control the model initialization process. An example model "
"initialization configuration is defined as follows:"
#: ../../source/initialize.rst:57 984a38d7f63949ecbb0d8b2ef3459d57
#: ../../source/initialize.rst:57
msgid "字段 ``model_type`` 指明了要初始化的模型类型"
msgstr ""
"The field ``model_type`` specifies which registered model type is to be "
"initialized."
#: ../../source/initialize.rst:58 9f04ad0f145f4e40bc75a3ef45c7a59d
#: ../../source/initialize.rst:58
msgid "字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置"
msgstr ""
"The parameters in the field ``model`` specify the configuration settings "
"used during model initialization."
#: ../../source/initialize.rst:60 d7780e355bb6429bb5151d9a0e6d7e36
#: ../../source/initialize.rst:60
msgid ""
"值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` "
"注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 "
@ -117,109 +112,90 @@ msgstr ""
" instantiated object of the class ``internlm.utils.registry.Registry``; an "
"example is shown as follows."
#: ../../source/initialize.rst:72 d863f71b208a49a09d2d00537e331962
#: ../../source/initialize.rst:72
msgid "优化器初始化"
msgstr "Optimizer Initialization"
#: acaafdc9bb96434bbd42a98f74187db1
#: internlm.train.training_internlm.initialize_optimizer:1 of
msgid "Initialize optimizer."
msgstr ""
#: 62fc4215c9a44bda8b31c933db90f270 93c398e44f6a4f708ba064250a3d253c
#: e2bebdd751724915a65dec444bb89e25
#: internlm.initialize.initialize_trainer.initialize_trainer
#: internlm.train.training_internlm.get_train_data_loader
#: internlm.train.training_internlm.initialize_optimizer of
msgid "参数"
msgstr ""
#: 2033ee96ded8423a80268b337ba9549c
#: internlm.train.training_internlm.initialize_optimizer:3 of
msgid "Your model instance to be trained or evaluated."
msgstr ""
#: df01b44c724b4326a6c85b44694262ba
#: internlm.train.training_internlm.initialize_optimizer:6 of
msgid "A tuple of (optimizer, beta2_scheduler, lr_scheduler)."
msgstr ""
#: ../../source/initialize.rst:79 0b46b890048f4758a9d56e0540759d9f
#: ../../source/initialize.rst:79
msgid "数据加载器初始化"
msgstr "Dataloader Initialization"
#: 58e39b26ab4849788e792df386f01d7e
#: internlm.train.training_internlm.get_train_data_loader:1 of
msgid "Generate and return the training data loader."
msgstr ""
#: 37a91c167e0b4e5fad4edcc3caf0d012
#: internlm.train.training_internlm.get_train_data_loader:3 of
msgid "number of subprocesses used for dataloader."
msgstr ""
#: 947aba2a4f86420d9b2660425a6043cc
#: internlm.train.training_internlm.get_train_data_loader:5 of
msgid "generate function for dataset."
msgstr ""
#: 8a8f5ee665cb4e15bc33194c0b1f346c
#: internlm.train.training_internlm.get_train_data_loader:7 of
msgid "dataset sampler for training dataloader."
msgstr ""
#: 4c3e1e896e7940bf97c124909d2e7f36
#: internlm.train.training_internlm.get_train_data_loader:9 of
msgid "collate function for training dataloader."
msgstr ""
#: d9f0740d048c48888e82c8f8a78e33cd
#: internlm.train.training_internlm.get_train_data_loader:12 of
msgid "A tuple of (train_dl, dataset_types)."
msgstr ""
#: ../../source/initialize.rst:86 1c4df708ff5c47f6abae32617bf2ed31
#: ../../source/initialize.rst:86
msgid "Trainer 初始化"
msgstr "Trainer Initialization"
#: d535583dbcb245499e19c09f3f8b534a
#: internlm.initialize.initialize_trainer.initialize_trainer:1 of
msgid ""
"Core function to wrap the essential training components with our "
"functionality based on the config which is loaded into gpc.config."
msgstr ""
#: 3e370234e4b245e4b9cae1fe235df8ff
#: internlm.initialize.initialize_trainer.initialize_trainer:4 of
msgid "Your model instance or a function to build the model."
msgstr ""
#: b716a4a264234011a7b51fa12e575651
#: internlm.initialize.initialize_trainer.initialize_trainer:6 of
msgid "Your optimizer for training."
msgstr ""
#: 6a54ce9d516f4f14bab281c9db9816e8
#: internlm.initialize.initialize_trainer.initialize_trainer:8 of
msgid "Your criterion instance."
msgstr ""
#: ff9dfd04d31b4dc6afbdd841829b4c33
#: internlm.initialize.initialize_trainer.initialize_trainer:10 of
msgid "Dataloader for training."
msgstr ""
#: de345f9a457a4a88bf60b4ee96535e31
#: internlm.initialize.initialize_trainer.initialize_trainer:12 of
msgid "Dataloader for testing."
msgstr ""
#: 64e646b25420424d9dcdfb1ad7de5e6f
#: internlm.initialize.initialize_trainer.initialize_trainer:14 of
msgid "Your lr scheduler instance, optional."
msgstr ""
#: 39c7132bfafe4e22ae373081fee711ce
#: internlm.initialize.initialize_trainer.initialize_trainer:17 of
msgid ""
"A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``"

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-07 10:56+0800\n"
"POT-Creation-Date: 2023-09-08 15:32+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n"
@ -19,122 +19,147 @@ msgstr ""
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../source/profiler.rst:2 81b1b5f4414449dfaf107815a911f300
#: ../../source/profiler.rst:2
msgid "性能分析"
msgstr "Profiler"
#: ../../source/profiler.rst:7 d709646ebb314e9abb6a4839a21180bd
#: ../../source/profiler.rst:7
msgid "Torch Profiler"
msgstr ""
#: ../../source/profiler.rst:9 4b5b73486c794c7a9168ad19999e12e1
#: ../../source/profiler.rst:9
msgid ""
"InternLM 使用 ``internlm.train.initialize_llm_profile()`` "
"来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler "
"<https://pytorch.org/docs/stable/profiler.html>`_ ,输出的性能分析 trace 文件可以使用 "
"`tensorboard <https://www.tensorflow.org>`_ 进行可视化。"
msgstr ""
"InternLM uses ``internlm.train.initialize_llm_profile()`` to profile performance data, execution time duration and breakdown analysis of "
"step time. The implementation is based on `torch.profiler <https://pytorch.org/docs/stable/profiler.html>`_ and output tracing files can "
"be visualized with `tensorboard <https://www.tensorflow.org>`_."
"InternLM uses ``internlm.train.initialize_llm_profile()`` to profile "
"performance data, execution time duration and breakdown analysis of step "
"time. The implementation is based on `torch.profiler "
"<https://pytorch.org/docs/stable/profiler.html>`_ and output tracing "
"files can be visualized with `tensorboard <https://www.tensorflow.org>`_."
#: ../../source/profiler.rst:11 40ff4289735c43fdbeca871b65e82be4
#: ../../source/profiler.rst:11
msgid ""
"用户如果想使用这个 torch 性能分析工具,需要在启动训练时传递 ``--profiling`` 参数以启用性能分析。完成 torch "
"性能分析后,用户可以在 ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` "
"文件夹中看到性能分析结果。"
msgstr ""
"To use this torch profiler tool, you need to enable profiling by passing the ``--profiling`` flag when starting training. After torch "
"profiling is completed, you can find the profiling results in the ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` folder."
"To use this torch profiler tool, you need to enable profiling by passing "
"the ``--profiling`` flag when starting training. After torch profiling is"
" completed, you can find the profiling results in the "
"``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` folder."
#: ../../source/profiler.rst:13
msgid "实际运行生成的 ``Torch Profiler`` 目录结构如下:"
msgstr "The directory structure of the files generated by ``Torch Profiler`` is as follows:"
#: ../../source/profiler.rst:22
msgid "其中, ``traces`` 可以通过 ``TensorBoard`` 可视化,运行命令"
msgstr "The ``traces`` can be visualized with ``TensorBoard`` by running the command"
#: ../../source/profiler.rst:29
msgid ""
"在打开的 ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` "
"页面可以看到Operator和GPU Kernel的性能分析时间线如下更多的功能请参考 `torch profiler with "
"tensorboard "
"<https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"
"#pytorch-profiler-with-tensorboard>`_"
msgstr "In the opened ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` page, you can see the timeline of profiled operators and GPU kernels. For more usage, please refer to `torch profiler with tensorboard <https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard>`_"
#: 876a2993b82645f7b56553fe64b514ec
#: internlm.train.training_internlm.initialize_llm_profile:1 of
msgid "Initialize and return the profiler context manager instance."
msgstr ""
#: ../../source/profiler.rst:16 3ab9536155ea4f3b8adb318005970bb8
#: ../../source/profiler.rst:38
msgid "Memory Profiler"
msgstr ""
#: ../../source/profiler.rst:18 0ec4091fef5b47c58488618bfb4dcd3b
#: ../../source/profiler.rst:40
msgid ""
"InternLM 提供了一个实用的内存分析工具 "
"``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` 来监控实际的 GPU"
" 内存使用情况。在实现中,会对模型数据(包括模型参数、模型梯度和优化器状态)和非模型数据(包括激活值)分别进行详细的统计。"
msgstr ""
"InternLM provides a practical solution ``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` to monitor actual GPU memory usage. "
"In the implementation, model data (including model parameters, model gradients, and optimizer states) and non-model data "
"(including activations) are calculated."
"InternLM provides a practical solution "
"``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` to monitor"
" actual GPU memory usage. In the implementation, model data (including "
"model parameters, model gradients, and optimizer states) and non-model "
"data (including activations) are calculated."
#: ../../source/profiler.rst:20 cd62bbd5b122480da21e10453b95090c
#: ../../source/profiler.rst:42
msgid ""
"要使用这个内存分析工具,用户需要在启动训练时传递 ``--profiling`` 参数以启用内存分析。完成内存分析后,用户可以在 "
"``memory_trace/rank{}_dp{}_tp{}`` 文件夹中找到特定 rank "
"对应的内存分析结果(包括不同时间点的内存使用日志和显示总体内存使用情况的太阳图表)。"
msgstr ""
"To use this memory profiler tool, you need to enable profiling by passing the ``--profiling`` flag when starting training. After memory "
"profiling is completed, you can find the profiling results (including logs of memory usage at different time points and sunburst charts "
"showing overall memory usage) for a specific rank device in the ``memory_trace/rank{}_dp{}_tp{}`` folder."
"To use this memory profiler tool, you need to enable profiling by passing"
" the ``--profiling`` flag when starting training. After memory profiling "
"is completed, you can find the profiling results (including logs of "
"memory usage at different time points and sunburst charts showing overall "
"memory usage) for a specific rank device in the "
"``memory_trace/rank{}_dp{}_tp{}`` folder."
#: ../../source/profiler.rst:44
msgid "实际运行生成的 ``memory_trace`` 目录结构如下:"
msgstr "The directory structure of the generated ``memory_trace`` files is as follows:"
#: ../../source/profiler.rst:107
msgid "其中, ``memory.log`` 的内容示例如下:"
msgstr "An example of ``memory.log`` is as follows:"
#: ../../source/profiler.rst:157
msgid "模型参数的太阳图示例如下:"
msgstr "An example of the model parameters sunburst chart is as follows:"
#: a858f1377b714cd5ab0cf749d8dbfeb7
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:1 of
msgid "A memory profiler for a llm model."
msgstr ""
#: 08d4cca2ba154080ba72e7d3fbd2a344 36e25696cf7b4a8ca5472e86fd5eea7e
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point of
msgid "参数"
msgstr ""
#: dea424767bc44ff689d582c67b07d637
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:3 of
msgid "The model to profile."
msgstr ""
#: 4f3892910fa14324810c3f33c6af4fdd
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:5 of
msgid "The optimizer used for training the model."
msgstr ""
#: a698f2f57eef4e47a22faa546c687979
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:7 of
msgid "The file to write the memory state information to."
msgstr ""
#: 448fc2b81c794d228ec4b413356289ea
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:9 of
msgid "number of steps to trace."
msgstr ""
#: 85b3b9d4147547fd89c286f003395469
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:1 of
msgid "Record the memory state."
msgstr ""
#: d474a46415674d35a2c87c57ebff20ea
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:3 of
msgid "The options to include in the memory state. Defaults to \"\"."
msgstr ""
#: 16261fe5b1df4b13bd23f76d97caf1be
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:5 of
msgid "Whether to create a new memory record file. Defaults to False."
msgstr ""
#: 3b18845958204f07a6b80b6afb2221f5 d11f76d03d0d456889dee6d267dd4b74
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step of
msgid "返回"
msgstr ""
#: 0deeb9555efb4aa798fd9d146826e961 46b50da453f1475a88e096b5d6ed8afb
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:8
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step:3 of
msgid "None"
msgstr ""
#: 4f2331ac352d4057a852b013ca688ed3
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step:1 of
msgid "Update the memory state of the optimizer state."
msgstr ""

View File

@ -10,6 +10,28 @@ InternLM uses ``internlm.train.initialize_llm_profile()`` to collect and analyze
To use this torch profiler tool, users need to pass the ``--profiling`` flag when starting training to enable profiling. After torch profiling is completed, the profiling results can be found in the ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` folder.
The directory structure generated by an actual ``Torch Profiler`` run is as follows:
.. code-block:: bash
# tree ./7b_train/Sep08_11-00-51/traces -L 2
./7b_train/Sep08_11-00-51/traces/
└── rank0_dp0_tp0_pp0
└── SH-IDC1-10-140-1-78_238619.1694142354680.pt.trace.json
The ``traces`` can be visualized with ``TensorBoard`` by running the command
.. code-block:: bash
# visualize traces with tensorboard and custom port
tensorboard --logdir rank0_dp0_tp0_pp0 --port 10088
On the opened ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` page, you can see the profiling timeline of operators and GPU kernels. For more features, please refer to `torch profiler with tensorboard <https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard>`_
.. figure:: ../../imgs/torch_profiler_trace.png
:scale: 45%
:class: with-border
.. autofunction:: internlm.train.initialize_llm_profile
Memory Profiler
@ -19,5 +41,124 @@ InternLM provides a practical memory profiling tool ``internlm.utils.simple_memo
To use this memory profiler, users need to pass the ``--profiling`` flag when starting training to enable memory profiling. After memory profiling is completed, the profiling results for a specific rank (including logs of memory usage at different time points and sunburst charts showing overall memory usage) can be found in the ``memory_trace/rank{}_dp{}_tp{}`` folder.
The directory structure of ``memory_trace`` generated by an actual run is as follows:
.. code-block:: bash
# tree ./memory_trace -L 2
./memory_trace
├── rank0_dp0_tp0 # Profiling results for a specific rank device
│   ├── activation_memory_sunburst.html # Sunburst chart showing activation memory usage
│   ├── grads_memory_sunburst.html # Sunburst chart showing gradient memory usage
│   ├── memory.log # Log of GPU memory usage at different time points
│   ├── os_memory_sunburst.html # Sunburst chart showing optimizer state memory usage
│   ├── params_memory_sunburst.html # Sunburst chart showing parameter memory usage
│   └── summary_sunburst.html # Sunburst chart showing overall memory usage
├── rank1_dp1_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
├── rank2_dp2_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
├── rank3_dp3_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
├── rank4_dp4_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
├── rank5_dp5_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
├── rank6_dp6_tp0
│   ├── activation_memory_sunburst.html
│   ├── grads_memory_sunburst.html
│   ├── memory.log
│   ├── os_memory_sunburst.html
│   ├── params_memory_sunburst.html
│   └── summary_sunburst.html
└── rank7_dp7_tp0
├── activation_memory_sunburst.html
├── grads_memory_sunburst.html
├── memory.log
├── os_memory_sunburst.html
├── params_memory_sunburst.html
└── summary_sunburst.html
An example of the contents of ``memory.log`` is as follows:
.. code-block:: bash
Memory State:
time: 37.56313228607178
---summary---
total_memory: 55953.56 MB
params_memory: 13965.51 MB, grads_memory: 13965.51 MB, os_params_memory: 3461.52 MB, os_state_memory: 6923.03 MB, activation_memory: 17638.00 MB
Memory State:
time: 38.46969723701477
---summary---
total_memory: 38315.56 MB
params_memory: 13965.51 MB, grads_memory: 13965.51 MB, os_params_memory: 3461.52 MB, os_state_memory: 6923.03 MB, activation_memory: 0.00 MB
---Layout---
params_layout:
layer: param_mem, layer_mem: 0.00 MB, total_mem: 13965.51 MB
layer: param_mem.embedding, layer_mem: 0.00 MB, total_mem: 806.00 MB
layer: param_mem.embedding.weight, layer_mem: 806.00 MB, total_mem: 806.00 MB
layer: param_mem.blocks, layer_mem: 0.00 MB, total_mem: 12353.50 MB
layer: param_mem.blocks.0, layer_mem: 0.00 MB, total_mem: 386.05 MB
layer: param_mem.blocks.0.mixer, layer_mem: 0.00 MB, total_mem: 128.03 MB
layer: param_mem.blocks.0.mixer.Wqkv, layer_mem: 0.00 MB, total_mem: 96.02 MB
layer: param_mem.blocks.0.mixer.Wqkv.weight, layer_mem: 96.00 MB, total_mem: 96.00 MB
layer: param_mem.blocks.0.mixer.Wqkv.bias, layer_mem: 0.02 MB, total_mem: 0.02 MB
layer: param_mem.blocks.0.mixer.out_proj, layer_mem: 0.00 MB, total_mem: 32.01 MB
layer: param_mem.blocks.0.mixer.out_proj.weight, layer_mem: 32.00 MB, total_mem: 32.00 MB
layer: param_mem.blocks.0.mixer.out_proj.bias, layer_mem: 0.01 MB, total_mem: 0.01 MB
layer: param_mem.blocks.0.norm1, layer_mem: 0.00 MB, total_mem: 0.01 MB
layer: param_mem.blocks.0.norm1.weight, layer_mem: 0.01 MB, total_mem: 0.01 MB
layer: param_mem.blocks.0.norm2, layer_mem: 0.00 MB, total_mem: 0.01 MB
layer: param_mem.blocks.0.norm2.weight, layer_mem: 0.01 MB, total_mem: 0.01 MB
layer: param_mem.blocks.0.mlp, layer_mem: 0.00 MB, total_mem: 258.00 MB
layer: param_mem.blocks.0.mlp.w1, layer_mem: 0.00 MB, total_mem: 86.00 MB
layer: param_mem.blocks.0.mlp.w1.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB
layer: param_mem.blocks.0.mlp.w2, layer_mem: 0.00 MB, total_mem: 86.00 MB
layer: param_mem.blocks.0.mlp.w2.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB
layer: param_mem.blocks.0.mlp.w3, layer_mem: 0.00 MB, total_mem: 86.00 MB
layer: param_mem.blocks.0.mlp.w3.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB
......
grads_layout:
......
os_params_layout:
......
os_state_layout:
......
activation_base_layout:
......
An example sunburst chart of the model parameters is as follows:
.. figure:: ../../imgs/params_memory_sunburst.png
:scale: 50%
:class: with-border
.. autoclass:: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler
:members:

View File

@ -112,19 +112,19 @@ If you want to load a model checkpoint when starting the training, you can confi
```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights; only model weights are loaded (no optimizer states), and training starts from the first step
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the model and optimizer states for resuming training; training resumes from the specified step
load_optimizer=True, # Whether to load optimizer weights when resuming training; default value is True
# When resuming training from a checkpoint:
# (1) 'path' is the path of the loaded checkpoint.
# (2) 'content' indicates which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# (3) 'ckpt_type' indicates which type of ckpt will be loaded, currently supported: "internlm"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```
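
For reference, a hedged sketch of how the legacy fields map onto the new `load_ckpt_info` (content tuples taken from the compatibility check added in this commit; the paths are the placeholders defined above):

```python
# Weights-only initialization (training starts from the first step):
load_ckpt_info = dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm")

# Full resume (legacy load_ckpt_folder with load_optimizer=True):
load_ckpt_info = dict(path=LOAD_CKPT_FOLDER, content=("model", "sampler", "optimizer"), ckpt_type="internlm")

# Resume without optimizer states (legacy load_optimizer=False):
load_ckpt_info = dict(path=LOAD_CKPT_FOLDER, content=("model", "sampler"), ckpt_type="internlm")
```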
Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
- If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS.
The configuration for the model is as follows:

Binary file not shown.

After

Width:  |  Height:  |  Size: 477 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 287 KiB

View File

@ -101,18 +101,17 @@ data = dict(
To load a model `checkpoint` when starting training, you can configure it as follows:
```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps; default value is inf
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights; only model weights are loaded (no optimizer states), and training starts from the first step
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the model and optimizer states for resuming training; training resumes from the specified step
load_optimizer=True, # Whether to load optimizer weights when resuming training; default value is True
# When resuming training, 'path' is the path to load the model and optimizer states from; training resumes from the specified step
# 'content' indicates which states will be loaded, supporting: "model", "sampler", "optimizer", "scheduler", "all"
# 'ckpt_type' indicates the type of checkpoint to be loaded, currently supported: "internlm"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```
Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time
- If a path starts with `local:`, the files are stored on the local file system; if it starts with `boto3:`, they are stored on remote OSS
The key model-related configuration parameters are as follows:

View File

@ -18,6 +18,7 @@ import torch.distributed as dist
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
from . import process_group_initializer as pgroup_initializer
from .process_group_initializer import ParallelMode
@ -36,7 +37,7 @@ class Config(dict):
config (dict): The dict object to be wrapped.
"""
def __init__(self, config: dict = None):
def __init__(self, config: dict = None): # pylint: disable=W0231
if config is not None:
for k, v in config.items():
self._add_item(k, v)
@ -100,7 +101,7 @@ class Config(dict):
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module() # pylint: disable=W4902,E1120
module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
# load into config
config = Config()
@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
"""
# initialize the default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
dist.init_process_group(
rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method,
timeout=LLM_NCCL_TIMEOUT,
)
# None will give the default global process group for pytorch dist operations
ranks = list(range(world_size))
if use_cpu:
cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
cpu_group = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else None
)
else:
cpu_group = None
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
@ -526,6 +537,7 @@ class ParallelContext(metaclass=SingletonMeta):
if dpseed_with_tpoffset:
dp_seed = seed + pipeline_offset * 1024
add_seed(ParallelMode.DATA, dp_seed)
add_seed(ParallelMode.DUMMY, dp_seed)
# model parallel seeds are different across ranks
if self.is_initialized(ParallelMode.TENSOR):
@ -533,7 +545,11 @@ class ParallelContext(metaclass=SingletonMeta):
tp_seed = seed + tp_rank + pipeline_offset * 1024
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
# We do not set the random state mode to ParallelMode.DATA until the model is built (instead, we use a dummy
# mode during model construction). This is because the random state would otherwise differ across tensor
# parallel devices of the same data parallel group: the device with tp_rank = 0 performs additional random
# operations while building the RowParallelLinear module.
set_mode(ParallelMode.DUMMY)
seeds = get_seeds()
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
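
To make the comment above concrete, here is a self-contained conceptual sketch (plain `torch`, not the internlm API) of keeping one RNG state per mode so that extra random ops during model construction land in a throwaway stream:

```python
import torch

# One saved RNG state per mode; switching modes swaps the global torch RNG state.
# (The real implementation also records which mode is current so its state can be
# saved back when switching away; omitted here for brevity.)
rng_states = {}

def add_seed(mode, seed):
    torch.manual_seed(seed)
    rng_states[mode] = torch.get_rng_state()

def set_mode(mode):
    torch.set_rng_state(rng_states[mode])

add_seed("data", 1024)
add_seed("dummy", 1024)

set_mode("dummy")      # model construction: tp_rank 0 may draw extra random numbers here
_ = torch.rand(2)

set_mode("data")       # training: the data-parallel stream is still aligned across tp ranks
```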

View File

@ -9,6 +9,8 @@ from enum import Enum
import torch.distributed as dist
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
# parallel modes
class ParallelMode(Enum):
@ -35,6 +37,9 @@ class ParallelMode(Enum):
# runtime network test
NETTEST = "nettest"
# dummy mode, only used during model construction
DUMMY = "dummy"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@ -106,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer):
for i in range(self.rank_num_per_dp_group):
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@ -158,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer):
for i in range(self.num_group):
ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@ -218,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
)
)
pipe_group_size = len(ranks)
pipe_group = dist.new_group(ranks)
pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else pipe_group
)
else:
group_cpu = None
@ -268,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@ -324,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
for k in range(self.zero1_parallel_size)
]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@ -373,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
rank = i * self.nettest_parallel_size + j
if rank < self.world_size:
ranks.append(rank)
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None

View File

@ -9,6 +9,7 @@ import torch
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
from internlm.utils.timeout import llm_timeout
from .base_scheduler import BaseScheduler, SchedulerHook
@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
return output, loss
@llm_timeout(func_name="nopp_forward_backward_step")
def forward_backward_step(
self,
engine: Engine,

View File

@ -15,6 +15,7 @@ from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
from .base_scheduler import BaseScheduler, SchedulerHook
@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss
@llm_timeout(func_name="nointerleaved_forward_backward_step")
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -1247,6 +1249,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
@llm_timeout(func_name="interleaved_forward_backward_step")
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.

View File

@ -23,7 +23,15 @@ class TrainState:
train_dl (DataLoader): The DataLoader object used for training.
"""
def __init__(self, config) -> None:
def __init__(self, config, batch_sampler) -> None:
"""
Args:
config (Config): internlm config
batch_sampler (torch.utils.data.Sampler): Because the dataloader loads data
asynchronously and prefetches batches, the batch_sampler state maintained inside
the dataloader runs ahead of the actual training progress, so we keep a copy of
the batch_sampler as the anchor point for checkpoint reloading.
"""
# The number of batches produced by the data iterator
self.batch_count: int = 0
# Used to store the number of samples consumed in the current epoch
@ -43,9 +51,20 @@ class TrainState:
self.tensorboard_folder = config.tensorboard_folder
def init_batch_sampler(self, train_dl):
# Copy of the batch sampler from the DataLoader
self.batch_sampler = train_dl.batch_sampler.copy()
# learning rate
self.lr = config.adam.lr
# sampler state
if batch_sampler:
self.init_batch_sampler(batch_sampler)
def init_batch_sampler(self, batch_sampler):
"""
Args:
batch_sampler (torch.utils.data.Sampler): sampler.
"""
# make a copy of batch_sampler.
self.batch_sampler = batch_sampler.copy()
# Iterator for the batch sampler
self.batch_sampler_iter = iter(self.batch_sampler)
@ -61,26 +80,22 @@ class TrainState:
return json.dumps(info, indent=4, sort_keys=True)
def load_state_dict(self, other_stuffs, train_dl):
def load_state_dict(self, other_stuffs):
"""
Resumes training from a checkpoint.
Args:
other_stuffs (dict): Other information needed to resume training.
train_dl (DataLoader): The DataLoader object used for training.
"""
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
# compatible with previous checkpoints without this parameter
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
# track the actual updates of sampler when using weighted sampling
if hasattr(self, "batch_sampler"):
self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)
# Because the ckpt save occurs after 'step_count' is updated, there is no need to
# increment 'step_count' here. However, 'batch_count' is updated before the ckpt is
# stored, so it needs to be incremented by 1 when resuming.
self.batch_count = other_stuffs["batch_count"] + 1  # resume from the next batch
self.step_count = other_stuffs.get("step_count", self.batch_count)
# resume tensorboard from older tensorboard_folder
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)

View File

@ -10,9 +10,10 @@ import torch
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.monitor import initialize_light_monitor
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
@ -97,6 +98,13 @@ def args_sanity_check():
if "valid_every" not in data:
data._add_item("valid_every", 0)
if "empty_cache_and_diag_interval" not in data:
data._add_item("empty_cache_and_diag_interval", 50)
if "diag_outlier_ratio" not in data:
data._add_item("diag_outlier_ratio", 1.1)
data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"seq_len: {data.seq_len}")
@ -111,7 +119,7 @@ def args_sanity_check():
# processing the checkpoint config
ckpt = gpc.config.ckpt
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", False)
ckpt._add_item("enable_save_ckpt", True)
# Saving checkpoint args.
if ckpt.enable_save_ckpt:
@ -137,9 +145,6 @@ def args_sanity_check():
if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None)
if "snapshot_ckpt_folder" not in ckpt:
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
if "oss_snapshot_freq" not in ckpt:
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
else:
@ -149,44 +154,23 @@ def args_sanity_check():
ckpt._add_item("async_upload", False)
ckpt._add_item("async_upload_tmp_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
# Loading checkpoint args.
if "load_model_only_folder" not in ckpt:
ckpt._add_item("load_model_only_folder", None)
if "load_ckpt_folder" not in ckpt:
ckpt._add_item("load_ckpt_folder", None)
if "load_optimizer" not in ckpt:
ckpt._add_item("load_optimizer", True)
if "stop_file_path" not in ckpt:
ckpt._add_item("stop_file_path", None)
if "load_given_ckpt" not in ckpt:
# If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
if "auto_resume" not in ckpt:
# If 'auto_resume' is not given, we set it to True, so internlm has the opportunity
# to auto-load the latest checkpoint.
ckpt._add_item("load_given_ckpt", False)
if ckpt.load_given_ckpt:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
logger.warning(
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
)
ckpt.load_model_only_folder = None
ckpt._add_item("auto_resume", True)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
# initialization storage manager
init_storage_manager(ckpt)
# tensorboard writer config
if "enable_tb" not in gpc.config:
@ -277,9 +261,22 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
if "alert_address" not in gpc.config:
gpc.config._add_item("alert_address", None)
# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
"monitor": { # new monitoring config
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
},
}
for key, value in monitor_default_config.items():
if key not in gpc.config:
gpc.config._add_item(key, value)
alert = gpc.config.monitor.alert
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
logger.warning("alert is enabled but alert_address is not set")
optim_ckpt = gpc.config.hybrid_zero_optimizer
if "zero_overlap_communication" in optim_ckpt:
@ -426,6 +423,7 @@ def launch_from_torch(
)
@llm_timeout(func_name="initialize_distributed_env")
def initialize_distributed_env(
config: str,
launcher: str = "slurm",
@ -459,3 +457,20 @@ def initialize_distributed_env(
if args_check:
args_sanity_check()
# init light monitor client
alert_config = gpc.config.monitor.alert
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
light_monitor_address = alert_config.light_monitor_address
if light_monitor_address:
initialize_light_monitor(light_monitor_address)
else:
logger.warning("light_monitor_address is None, the light monitor cannot be used!")
def get_config_value(config, key, default):
try:
value = config[key]
except KeyError:
value = default
return value
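
A quick usage note for the helper above (the import path is the one used elsewhere in this commit):

```python
from internlm.initialize.launch import get_config_value

ckpt_cfg = {"load_ckpt_folder": "local:llm_ckpts/"}
print(get_config_value(ckpt_cfg, "load_ckpt_folder", None))  # "local:llm_ckpts/"
print(get_config_value(ckpt_cfg, "load_optimizer", True))    # True (key missing -> default)
```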

View File

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from internlm.initialize.launch import get_config_value
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def auto_resume_sanity_check(ckpt_config):
load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
if load_given_ckpt is None:
return True # default value is True
else:
return not load_given_ckpt
def ckpt_info_sanity_check(ckpt_config):
load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
if load_model_only_folder is not None:
assert (
load_ckpt_folder is None
), "Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time; \
only one of them may be set."
return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
else:
load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
if isinstance(load_ckpt_folder, str):
if load_optimizer:
return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
else:
return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
elif load_ckpt_folder is None:
return None
else:
raise AssertionError(f"Unsupported data type '{type(load_ckpt_folder)}' for config.ckpt arg 'load_ckpt_folder'")
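
For example, with a legacy-style ckpt config (a sketch; it assumes the two helpers above are in scope, since this file's module path is not shown in the diff):

```python
legacy_ckpt = {"load_ckpt_folder": "local:llm_ckpts/", "load_optimizer": True}

print(ckpt_info_sanity_check(legacy_ckpt))
# {'path': 'local:llm_ckpts/', 'content': ('model', 'sampler', 'optimizer'), 'ckpt_type': 'internlm'}

print(auto_resume_sanity_check({"load_given_ckpt": True}))  # False: explicit ckpt given, no auto-resume
print(auto_resume_sanity_check({}))                         # True: default is to auto-resume
```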

View File

@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch
@ -195,12 +195,6 @@ class FeedForward(nn.Module):
device=device,
dtype=dtype,
)
# need to assign tp attribute so that colossalai know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name in ["w1", "w2", "w3"]:
for param in getattr(self, name).parameters():
setattr(param, IS_TENSOR_PARALLEL, True)
def forward(self, x):
out = self.w3(F.silu(self.w1(x)) * self.w2(x))

View File

@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module):
device=device,
dtype=dtype,
)
for _, param in self.mlp.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init

View File

@ -1,4 +1,11 @@
from .alert import initialize_light_monitor, send_heartbeat
from .monitor import initialize_monitor_manager, send_alert_message
from .utils import set_env_var
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
__all__ = [
"send_alert_message",
"initialize_monitor_manager",
"set_env_var",
"initialize_light_monitor",
"send_heartbeat",
]

View File

@ -1,8 +1,59 @@
import json
import math
import os
import re
import time
from typing import Dict
import requests
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def initialize_light_monitor(monitor_address: str = None):
try:
from uniscale_monitoring import init_monitor
init_monitor(monitor_address)
except Exception as e:
logger.warning(f"init monitor met an error: {e}")
def send_heartbeat(msg_type: str, msg: Dict):
def nan2none(v):
if isinstance(v, float) and math.isnan(v):
return None
return v
try:
from uniscale_monitoring import send_meta
data = {}
for k, v in msg.items():
if isinstance(v, Dict):
for k1, v1 in v.items():
new_k = f"{k}_{k1}".split(" ")[0]
new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
data[new_k] = nan2none(v1)
else:
new_k = k.split(" ")[0]
new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
data[new_k] = nan2none(v)
if os.getenv("CLUSTER_NAME"):
data.update({"cluster": os.getenv("CLUSTER_NAME")})
if msg_type == "train_metrics":
data.update({"msg_type": "train_metrics"})
elif msg_type == "init_time":
data.update({"msg_type": "init_time"})
elif msg_type == "stage_time":
data.update({"msg_type": "stage_time"})
send_meta(data, timeout=0.1)
except Exception as e:
logger.warning(f"send heartbeat met an error: {e}")
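
As a rough, self-contained illustration of the metric flattening performed by `send_heartbeat` above (a sketch of the same key handling, not the monitoring client itself):

```python
import math
import re
from typing import Dict

def flatten_metrics(msg: Dict) -> Dict:
    """Sketch: flatten nested metric dicts, sanitize keys, and map NaN values to None."""
    def clean(key: str) -> str:
        return re.sub(r"[^a-zA-Z0-9_]", "_", key.split(" ")[0])
    def nan2none(v):
        return None if isinstance(v, float) and math.isnan(v) else v
    data = {}
    for k, v in msg.items():
        if isinstance(v, dict):
            for k1, v1 in v.items():
                data[clean(f"{k}_{k1}")] = nan2none(v1)
        else:
            data[clean(k)] = nan2none(v)
    return data

print(flatten_metrics({"loss": float("nan"), "grad_norm": {"0_default": 1.5}}))
# {'loss': None, 'grad_norm_0_default': 1.5}
```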
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
"""

View File

@ -226,9 +226,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
yield
finally:
send_alert_message(
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
)
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
monitor_manager.stop_monitor()
else:
yield
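
A hedged usage sketch of this context manager wrapping a training entry point (it assumes `gpc.config` has already been initialized and that a `main()` training function exists; both are assumptions here):

```python
from internlm.core.context import global_context as gpc
from internlm.monitor import initialize_monitor_manager

# Start/stop alerts go to the configured feishu webhook around the training run.
with initialize_monitor_manager(
    job_name=gpc.config.JOB_NAME,
    alert_address=gpc.config.monitor.alert.feishu_alert_address,
):
    main()  # hypothetical training entry point
```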

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .hybrid_zero_optim import HybridZeroOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff
__all__ = ["HybridZeroOptimizer"]
__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"]

View File

@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.timeout import llm_timeout
from .utils import compute_norm
@ -124,6 +125,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
self._bucket_store = BucketStore(ParallelMode.DATA)
self._bucket_in_progress = []
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
@ -133,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._comm_bcast_stream = torch.cuda.Stream()
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
@ -231,13 +235,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
self.skip_grad_reduce = False
# initialize communication stream for
# communication-computation overlapping
if self._overlap_sync_grad:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_sync_grad:
@ -383,34 +380,41 @@ class HybridZeroOptimizer(BaseOptimizer):
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
next_bucket_list = []
# add parameters into bucket for reduction
for tensor_list in grad_buckets_by_dtype:
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
next_bucket_list.append(param_bucket)
# wait for the completion of the previous bucket list reduction, and do unflatten_and_copy()
# here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
self._bucket_in_progress = next_bucket_list.copy()
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_sync_grad:
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# flatten the tensors and do allreduce
bucket.flatten()
bucket.commu_handle = reduce_tensor(
tensor=bucket.get_flat_tensor(),
dtype=None,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._zero_local_rank:
bucket.unflatten_and_copy(reduced_flat)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._zero_local_rank:
bucket.set_unflatten_and_copy_flag(flag=True)
def _has_inf_or_nan(self, tensor):
try:
@ -506,6 +510,7 @@ class HybridZeroOptimizer(BaseOptimizer):
return norm
@llm_timeout(func_name="optim_step")
def step(self, closure=None):
"""Performs a single optimization step.
@ -534,10 +539,13 @@ class HybridZeroOptimizer(BaseOptimizer):
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
if self._overlap_sync_grad:
# grads in the last bucket is reduced
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# grads in the last bucket are reduced
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
# compute norm for gradients in the last bucket
total_norms = {}
@ -562,6 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# check for overflow
found_inf = False
found_nan = False
# if there are INF values in grads, the compute_norm func also returns -1
# thus, we try to avoid call _check_overflow here
# found_inf = self._check_overflow()
@ -570,21 +579,36 @@ class HybridZeroOptimizer(BaseOptimizer):
if -1 in norms.values():
found_inf = True
if -2 in norms.values():
found_nan = True
loss_scale = float(self.loss_scale.item()) # backup
if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
address=gpc.config.monitor.alert.feishu_alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, norms
if found_nan:
if gpc.is_rank_for_log():
logger.warning("Nan grad norm occurs, please check it.")
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message="Nan grad norm occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, norms
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
for group_id in range(self.num_param_groups):
@ -624,7 +648,9 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
self._unscale_and_clip_grads(
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
single_grad_partition_groups,
list(global_norm_groups.values()),
loss_scale,
)
# update the parameters
@ -645,7 +671,9 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
self.broadcast_params()
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()
timer("step").stop()
@ -771,3 +799,17 @@ class HybridZeroOptimizer(BaseOptimizer):
if "zero_devide_optim_plan" in states:
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
def reload_zero_fp32_buff(optimizer):
# If we use an AMP optimizer, we need to update its fp32 buffer with the newly loaded weight values.
# Otherwise, we must ensure that model weights are loaded before zero is initialized.
if isinstance(optimizer, HybridZeroOptimizer):
for group_id, param_group in enumerate(optimizer.optim.param_groups):
if optimizer.param_group_has_params[group_id]:
# flattened fp16 params have already been updated by 'load_model_checkpoint'
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
optimizer._zero_local_rank, group_id
)
# param_group["params"][0] is the flattened fp32 master copy (optimizer state) of this zero rank.
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
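
To make the copy above concrete, a tiny self-contained sketch of the same refresh (plain tensors standing in for the flattened parameter groups):

```python
import torch

# The optimizer steps on an fp32 master copy of the flattened fp16 parameters.
# After a checkpoint load overwrites the fp16 flat params, the fp32 copy must be
# refreshed from them -- the same .data.copy_() performed by reload_zero_fp32_buff.
fp16_flat = torch.randn(8, dtype=torch.float16)  # stands in for the loaded fp16 flat params
fp32_master = torch.zeros(8)                     # stands in for param_group["params"][0]

fp32_master.data.copy_(fp16_flat.float())
```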

View File

@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
if not last_bucket:
if group_id not in self._former_bucket_reduced_param:
return [], []
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
return (
self._former_bucket_reduced_param[group_id],
self._former_bucket_reduced_grad[group_id],
)
else:
if group_id not in self._last_bucket_reduced_param:
return [], []
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
return (
self._last_bucket_reduced_param[group_id],
self._last_bucket_reduced_grad[group_id],
)
def reset_reduced_data_for_compute_norm(self):
self._former_bucket_reduced_param = {}
@ -277,6 +283,9 @@ class TensorBucket:
self._max_size = size
self._current_size = 0
self._bucket = []
self._flat_tensor = None
self._unflatten_and_copy_flag = False
self.commu_handle = None
@property
def max_size(self):
@ -292,6 +301,15 @@ class TensorBucket:
def is_empty(self):
return len(self._bucket) == 0
def set_unflatten_and_copy_flag(self, flag):
self._unflatten_and_copy_flag = flag
def get_unflatten_and_copy_flag(self):
return self._unflatten_and_copy_flag
def get_flat_tensor(self):
return self._flat_tensor
def add_to_bucket(self, tensor, allow_oversize=False):
tensor_size = tensor.numel()
@ -312,11 +330,14 @@ class TensorBucket:
def empty(self):
self._bucket = []
self._size = 0
self._flat_tensor = None
self.commu_handle = None
def flatten(self):
return _flatten_dense_tensors(self._bucket)
self._flat_tensor = _flatten_dense_tensors(self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def unflatten_and_copy(self):
if self._unflatten_and_copy_flag:
unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
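A small self-contained sketch of the deferred-unflatten pattern the TensorBucket changes introduce: flatten the bucket once, let a (possibly asynchronous) reduce rewrite the flat tensor, and copy back into the original tensors only when the flag says the communication really ran. The bucket contents are invented; the two torch helpers are the same ones used above.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

bucket = [torch.ones(3), torch.ones(2) * 2]      # illustrative gradient tensors
flat = _flatten_dense_tensors(bucket)            # what flatten() now stores in _flat_tensor

flat.mul_(0.5)                                   # stand-in for the reduce op acting on the flat tensor

unflatten_and_copy_flag = True                   # set only once the communication handle completed
if unflatten_and_copy_flag:                      # what unflatten_and_copy() now does
    for old, new in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
        old.copy_(new)

print(bucket)                                    # the original tensors now hold the reduced values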

View File

@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
:type parallel_mode: ParallelMode, optional
"""
# use the original dtype
if dtype is None:
dtype = tensor.dtype
# if dtype is None:
assert dtype is None
dtype = tensor.dtype
# cast the data to specified dtype for reduce/all-reduce
if tensor.dtype != dtype:
tensor_to_reduce = tensor.to(dtype)
else:
tensor_to_reduce = tensor
# if tensor.dtype != dtype:
# tensor_to_reduce = tensor.to(dtype)
# else:
# tensor_to_reduce = tensor
world_size = gpc.get_world_size(parallel_mode)
# world_size = gpc.get_world_size(parallel_mode)
# tensor.div_(world_size)
group = gpc.get_group(parallel_mode)
tensor_to_reduce.div_(world_size)
# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
else:
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
handle = dist.reduce(
tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
)
# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
local_rank = gpc.get_local_rank(parallel_mode)
if use_all_reduce or dst_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
return handle
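A hedged usage sketch of the new asynchronous averaging reduce: async_op=True returns a work handle, and the reduced tensor is only valid after handle.wait(). It assumes a NCCL process group has already been initialized (ReduceOp.AVG is not available on every backend); the tensor name is illustrative.

import torch
import torch.distributed as dist

def launch_avg_all_reduce(tensor: torch.Tensor, group=None):
    # assumes dist.init_process_group("nccl", ...) was called beforehand
    return dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=group, async_op=True)

# handle = launch_avg_all_reduce(grad_tensor)
# ... overlap other work here ...
# handle.wait()   # the averaged result is only safe to read after this point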
def has_inf_or_nan(tensor):
@ -314,6 +311,9 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
if total_norm == float("inf") or total_norm == -float("inf"):
total_norm = -1
if math.isnan(total_norm):
total_norm = -2
return total_norm
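The hunk above folds degenerate norms into sentinel values (-1 for an infinite norm, -2 for NaN) so callers can branch on them without re-checking. A tiny sketch of that convention:

import math

def encode_norm(total_norm: float) -> float:
    # same sentinel convention as compute_norm above
    if total_norm in (float("inf"), -float("inf")):
        return -1.0
    if math.isnan(total_norm):
        return -2.0
    return total_norm

assert encode_norm(float("inf")) == -1.0
assert encode_norm(float("nan")) == -2.0
assert encode_norm(3.5) == 3.5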

View File

@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.random import set_mode
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
@ -24,7 +25,7 @@ from internlm.data.packed_dataset import (
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.monitor import set_env_var
from internlm.monitor import send_heartbeat, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
@ -39,10 +40,12 @@ from internlm.utils.parallel import (
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
@llm_timeout(func_name="initialize_model")
def initialize_model():
"""
Initialize model with Automatic Mixed Precision.
@ -82,9 +85,14 @@ def initialize_model():
# the same across tensor parallelism.
sync_model_param_within_tp(model)
# Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
# state in the same dp group are all the same.
set_mode(ParallelMode.DATA)
return model
@llm_timeout(func_name="initialize_optimizer")
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
@ -122,6 +130,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
return optimizer, beta2_scheduler, lr_scheduler
@llm_timeout(func_name="get_train_data_loader")
def get_train_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
@ -201,6 +210,7 @@ def get_train_data_loader(
return train_dl, dataset_types
@llm_timeout(func_name="get_validation_data_loader")
def get_validation_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
@ -262,6 +272,7 @@ def get_validation_data_loader(
return val_dls
@llm_timeout(func_name="load_new_batch")
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
@ -319,6 +330,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
)
@llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics(
get_tflops_func,
logger,
@ -342,6 +354,7 @@ def record_current_batch_training_metrics(
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
timer.store_last_timers()
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
@ -404,6 +417,9 @@ def record_current_batch_training_metrics(
else:
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
send_heartbeat("train_metrics", infos)
if update_panel:
# metrics shown with dashboard panels
panel_metrics = {
@ -429,4 +445,8 @@ def record_current_batch_training_metrics(
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
mm.monitor_loss_spike(
alert_address=gpc.config.monitor.alert.feishu_alert_address,
step_count=batch_count,
cur_step_loss=loss.item(),
)

View File

@ -9,7 +9,9 @@ import torch.distributed as dist
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
from torch.utils import benchmark
from internlm.monitor import send_alert_message
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
try:
import GPUtil
@ -24,6 +26,23 @@ from internlm.utils.common import get_current_device
logger = get_logger(__file__)
def empty_cache_and_diag(batch_count, interval=50):
"""empty cuda cache and run diag bench or tests."""
if interval <= 0:
interval = 50
if batch_count % int(interval) == 0:
# there is no need to do diag on the first batch
if batch_count > 0:
if gpc.is_rank_for_log():
logger.info("Empty Cache and Diagnosis GPU/NCCL/Timer ...")
with torch.no_grad():
timer_diagnosis()
bench_gpu()
bench_net()
# do empty_cache after the bench
torch.cuda.empty_cache()
def benchmark_forward(
test_fn,
*inputs,
@ -81,14 +100,78 @@ def get_cpu_temperature():
return cpu_temperature
def timer_diagnosis():
"""Diagnosis running time"""
if len(timer.names) == 0 or len(timer.times) == 0:
return
world_size = gpc.get_world_size(ParallelMode.DATA)
if world_size < 2:
return
# if gpc.is_rank_for_log():
# logger.info("Diagnosis running timers ...")
# detect slow rank compared to other ranks in the same DP group
running_time = torch.Tensor(timer.times).to(device=get_current_device())
avg_time = running_time.detach().clone()
if world_size <= 4:
dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.DATA))
else:
running_time_max = avg_time.detach().clone()
running_time_min = avg_time.detach().clone()
dist.all_reduce(running_time_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(running_time_min, op=torch.distributed.ReduceOp.MIN, group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
avg_time = (avg_time - running_time_max - running_time_min) / (world_size - 2)
diag_result = running_time > avg_time * gpc.config.data.diag_outlier_ratio
diag_result = diag_result.tolist()
avg_time = avg_time.tolist()
for slow, name, time, avg in zip(diag_result, timer.names, timer.times, avg_time):
if slow is False or avg < 0.5:
continue
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than avg on {name}, "
f"Hostname {socket.gethostname()}, "
f"its time {time:.2f}, avg {avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)
# detect slow rank compared to historical timer data
for name, time in zip(timer.names, timer.times):
if name not in timer.hist or len(timer.hist[name]) < 5:
continue
hist_avg = sum(timer.hist[name]) / len(timer.hist[name])
if time > hist_avg * gpc.config.data.diag_outlier_ratio and time > 0.5:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than hist avg on {name}, "
f"Hostname {socket.gethostname()}, "
f"its time {time:.2f}, hist_avg {hist_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)
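The slow-rank detection above uses a trimmed mean for larger groups: the fastest and slowest rank are excluded before averaging, and a rank is flagged when its time exceeds avg * diag_outlier_ratio. A purely local sketch of that arithmetic (no process group; the timings and ratio are made up):

import torch

ratio = 1.1                                                  # stand-in for gpc.config.data.diag_outlier_ratio
times = torch.tensor([1.00, 1.02, 0.98, 1.01, 1.60, 0.99])   # per-rank times for one timer (illustrative)
world_size = times.numel()

# what the SUM/MAX/MIN all-reduces compute: an average excluding the two extremes
avg = (times.sum() - times.max() - times.min()) / (world_size - 2)
slow = times > avg * ratio

print(round(avg.item(), 3))   # ~1.005
print(slow.tolist())          # only the 1.60 entry is flagged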
def bench_net():
"""Benchmark nccl performance for slow node detection."""
if gpc.get_world_size(ParallelMode.GLOBAL) <= 1:
return
if gpc.is_rank_for_log():
logger.info("benchmarking network speed ...")
# if gpc.is_rank_for_log():
# logger.info("benchmarking network speed ...")
repeats = 100
input_data = torch.randn(
@ -113,20 +196,25 @@ def bench_net():
allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL)
allreduce_time_avg = float(allreduce_time_avg.item())
if allreduce_time_this >= allreduce_time_avg * 1.05:
logger.warning(
if allreduce_time_this >= allreduce_time_avg * gpc.config.data.diag_outlier_ratio:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)
def bench_gpu(use_flash_attn=True):
"""Benchmark single GPU performance for slow node detection."""
if gpc.is_rank_for_log():
logger.info("benchmarking gpu speed ...")
# if gpc.is_rank_for_log():
# logger.info("benchmarking gpu speed ...")
headdim = 64
dim = 2048
@ -154,10 +242,15 @@ def bench_gpu(use_flash_attn=True):
speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL)
speed_avg = float(speed_avg.item())
if speed_this <= speed_avg * 0.95:
logger.warning(
if speed_this <= speed_avg / gpc.config.data.diag_outlier_ratio:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)

View File

@ -84,7 +84,7 @@ def initialize_uniscale_logger(
job_name and launch_time and file_name
), "If file_path is None, job_name, launch_time and file_name must be setted."
log_file_name = file_name
log_folder = os.path.join(job_name, launch_time, "logs")
log_folder = os.path.join("RUN", job_name, launch_time, "logs")
log_dir = os.path.join(log_folder, log_file_name)
file_path = log_dir

View File

@ -16,8 +16,12 @@ class _Timer:
self.start_time = time.time()
self.stream = torch.cuda.current_stream()
def start(self):
def start(self, reset_all=True):
"""Start the timer."""
# need to reset all timers in a new batch
if self.name_ == "one-batch" and reset_all is True:
megatron_timer.reset()
assert not self.started_, "timer has already been started"
self.stream.synchronize()
self.start_time = time.time()
@ -48,7 +52,7 @@ class _Timer:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
self.start(reset_all=False)
return elapsed_
@ -57,12 +61,29 @@ class Timers:
def __init__(self):
self.timers = {}
self.hist = {}
self.names = []
self.times = []
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def store_last_timers(self):
"""Store timers to two list"""
self.names = []
self.times = []
for key, value in self.timers.items():
senconds = round(float(value.elapsed(reset=False)), 4)
self.names.append(key)
self.times.append(senconds)
if key not in self.hist:
self.hist[key] = []
self.hist[key].append(senconds)
if len(self.hist[key]) > 10:
self.hist[key].pop(0)
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,

View File

@ -3,37 +3,136 @@
import copy
import fcntl
import inspect
import os
import socket
import time
from enum import Enum
from typing import Dict
from typing import Callable, Dict, Union
import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.initialize.launch import get_config_value
from internlm.initialize.legacy.launch import (
auto_resume_sanity_check,
ckpt_info_sanity_check,
)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
init_storage_manager,
llm_load,
llm_save,
try_get_storage_backend,
)
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
class CheckpointType(Enum):
class CheckpointSaveType(Enum):
NORMAL_CHECKPOINT = 1
SNAPSHOT_CHECKPOINT = 2
class CheckpointLoadType(Enum):
INTERNLM = "internlm"
# The load method implemented by internlm by default does not use string representation types,
# but uses enumeration types defined in advance.
LOAD_TYPE_DICT = {
"internlm": CheckpointLoadType.INTERNLM,
}
class CheckpointLoadContent:
MODEL = "model"
SAMPLER = "sampler"
OPIMIZER = "optimizer"
SCHEDULAER = "scheduler"
class CheckpointLoadMethod:
"""The registration class of the checkpoint loading method,
users can define their own custom ckpt loading methods."""
LOAD_FUNC_SIG = None
LOAD_TYPE_FUNC = {}
@staticmethod
def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
if load_type.lower() in LOAD_TYPE_DICT:
# The ckpt load method implemented by internlm by default.
return LOAD_TYPE_DICT[load_type.lower()]
else:
# If it is a user-defined field, we do not do any conversion and represent it as a string.
return load_type
@staticmethod
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
logger.warning(f"{load_type} has aleady been registed!")
return
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
if load_type == CheckpointLoadType.INTERNLM:
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
logger.warning(
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
)
@staticmethod
def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]):
return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]
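A standalone mock of the registration pattern above, not the real internlm API: a dict maps each load type to a loader callable, and the default entry pins the signature that later registrations are checked against. The "huggingface" type and my_hf_loader are invented for illustration.

import inspect

LOAD_TYPE_FUNC = {}
LOAD_FUNC_SIG = None

def register_ckpt_load_type(load_type, load_func, is_default=False):
    global LOAD_FUNC_SIG
    if load_type in LOAD_TYPE_FUNC:
        print(f"{load_type} has already been registered!")
        return
    LOAD_TYPE_FUNC[load_type] = load_func
    if is_default:
        LOAD_FUNC_SIG = inspect.signature(load_func)
    elif inspect.signature(load_func) != LOAD_FUNC_SIG:
        print(f"registered signature does not match: {LOAD_FUNC_SIG}")

def default_internlm_loader(ckpt_mm, load_info, train_state):   # stub default loader
    return "model, optimizer"

def my_hf_loader(ckpt_mm, load_info, train_state):               # hypothetical custom loader
    return "model"

register_ckpt_load_type("internlm", default_internlm_loader, is_default=True)
register_ckpt_load_type("huggingface", my_hf_loader)
print(LOAD_TYPE_FUNC["huggingface"] is my_hf_loader)              # True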
class CheckpointLoadMask:
"""
According to the content field in the incoming ckpt_info, decide which components to load.
"""
LOAD_CONTENT_DICT = {
"model": CheckpointLoadContent.MODEL,
"sampler": CheckpointLoadContent.SAMPLER,
"optimizer": CheckpointLoadContent.OPIMIZER,
"scheduler": CheckpointLoadContent.SCHEDULAER,
}
def __init__(self, content: tuple) -> None:
self.load_set = set(map(lambda x: x.lower(), content))
if "all" in self.load_set:
self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values())
else:
self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content))
def need_load(self, content: CheckpointLoadContent):
return content in self.load_set
def not_only_load(self, content: CheckpointLoadContent):
return content in self.load_set and len(self.load_set) > 1
def only_load(self, content: CheckpointLoadContent):
return set((content,)) == self.load_set
def __str__(self) -> str:
return f"{self.load_set}."
def __repr__(self) -> str:
return f"{self.load_set}."
def get_model_topology(model):
"""
Returns:
@ -55,6 +154,66 @@ def get_model_topology(model):
return topos
def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
load_content_str = ""
load_ckpt_folder = load_info["path"]
load_content: CheckpointLoadMask = load_info["content"]
if gpc.is_rank_for_log():
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
if load_content.need_load(CheckpointLoadContent.MODEL):
load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
load_content_str += f"{CheckpointLoadContent.MODEL}, "
if load_content.not_only_load(CheckpointLoadContent.MODEL):
# load training states.
load_context(load_ckpt_folder, train_state)
# load optimizer states.
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")
# load lr scheduler states.
if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
if ckpt_mm.lr_scheduler:
load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state)
load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
# load dataloader sampler states.
if load_content.need_load(CheckpointLoadContent.SAMPLER):
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler)
# track the actual updates of sampler when using weighted sampling
train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler)
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager skip reload 'batch_sampler'")
# reload data state dict.
if hasattr(train_state, "data_state_dict"):
ckpt_mm.train_dl.dataset.load_state_dict(
llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder
)
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
else:
if gpc.is_rank_for_log():
logger.warning(
"CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!"
)
return load_content_str
def save_model_checkpoint(folder, model):
"""
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@ -233,15 +392,16 @@ def load_sampler(ckpt_path: str, sampler):
torch.cuda.empty_cache()
def load_context(ckpt_path: str, train_dl, train_state: TrainState):
def load_context(ckpt_path: str, train_state: TrainState):
context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
train_state.load_state_dict(context_stuffs, train_dl)
train_state.load_state_dict(context_stuffs)
if gpc.is_rank_for_log():
logger.info(f"reload train_state:{train_state}")
torch.cuda.empty_cache()
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState):
learning_rate = train_state.lr
scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
logger.warning(
@ -270,7 +430,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
class CheckpointManager:
"""StorageManagerContext"""
def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
def __init__(
self,
ckpt_config,
model,
train_dl=None,
optimizer=None,
lr_scheduler=None,
model_config=None,
model_config_file=None,
feishu_address=None,
) -> None:
"""
CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
upload mode, you must call wait_async_upload_finish at the end of the program to wait
@ -283,22 +453,44 @@ class CheckpointManager:
lr_scheduler (object): lr_scheduler obj.
model_config (dict): model config.
"""
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
self.checkpoint_every = ckpt_config.checkpoint_every
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.stop_file_path = ckpt_config.stop_file_path
self.load_model_only_folder = ckpt_config.load_model_only_folder
self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100)
self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
if self.save_ckpt_folder:
self.snapshot_ckpt_folder = get_config_value(
ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
)
self.async_upload_tmp_folder = get_config_value(
ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/"
)
else:
self.snapshot_ckpt_folder = None
self.async_upload_tmp_folder = None
self.async_upload = get_config_value(ckpt_config, "async_upload", False)
# initialization storage manager
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)
self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.load_optimizer = gpc.config.ckpt.load_optimizer
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dl = train_dl
self.model_config = model_config
self.model_config_file = model_config_file
# Register the default internlm ckpt load type.
self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
for ckpt_load_type in CheckpointLoadType:
CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
# Init the alert/stop-signal file.
if self.stop_file_path and gpc.get_global_rank() == 0:
dir_path = os.path.dirname(self.stop_file_path)
if dir_path != "" and not os.path.exists(dir_path):
@ -306,21 +498,35 @@ class CheckpointManager:
with open(self.stop_file_path, "w", encoding="utf-8") as f:
f.write("0")
if ckpt_config.load_given_ckpt is False:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
latest_ckpt_path = self.query_lastest_ckpt()
if latest_ckpt_path:
self.load_ckpt_folder = latest_ckpt_path
else:
# At this time, we have to load model init weights and train from step 0.
self.load_ckpt_folder = self.load_model_only_folder
else:
self.load_ckpt_folder = ckpt_config.load_ckpt_folder
self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
if self.load_ckpt_info is None: # (legacy): try to stay compatible with old interfaces
self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
if self.stop_file_path is None:
logger.warning("no set stop_file_path, quit_signal_handler is disable")
# Auto-reload the latest checkpoint; this overwrites the 'load_ckpt_info' setting.
self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
if self.auto_resume is None: # (legacy): try to stay compatible with old interfaces
self.auto_resume = auto_resume_sanity_check(ckpt_config)
if self.auto_resume:
self.load_ckpt_info = self.query_lastest_ckpt()
if self.stop_file_path is None and gpc.is_rank_for_log():
logger.warning("no set stop_file_path, quit_signal_handler is disable")
# convert to internal representation
if self.load_ckpt_info:
assert (
"path" in self.load_ckpt_info
and "content" in self.load_ckpt_info
and "ckpt_type" in self.load_ckpt_info
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
# replace load_ckpt
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
# test storage setting is ok.
if self.enable_save_ckpt:
self.try_ping_storage()
def quit_signal_handler(self, train_state) -> bool:
"""
@ -334,7 +540,7 @@ class CheckpointManager:
Returns:
bool: whether to quit.
"""
now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT
if self.stop_file_path is None:
return now_break, now_save_ckpt, save_type
@ -365,24 +571,29 @@ now step_count is {train_state.step_count}",
return now_break, now_save_ckpt, save_type
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return False
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
now_break, signal_save_ckpts, signal_save_type = self.quit_signal_handler(train_state)
if save_ckpts is False:
save_ckpts = signal_save_ckpts
save_type = signal_save_type
return save_ckpts, save_type, now_break
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return False
save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
self.storage_manager.wait()
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
# Snapshot number, with only two snapshots written alternately.
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
@ -412,7 +623,7 @@ now step_count is {train_state.step_count}",
Tuple(str, int): path of the latest ckpt and its step; if not found, None will be returned.
"""
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
if len(ckpt_list) == 0:
if ckpt_list is None or len(ckpt_list) == 0:
return None, None
max_normal_step = 0
@ -435,14 +646,16 @@ now step_count is {train_state.step_count}",
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
max_step_0, max_step_1 = 0, 0
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
if ckpt_list_1:
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
if ckpt_list_2:
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
snap_step = max(max_step_0, max_step_1)
@ -452,11 +665,12 @@ now step_count is {train_state.step_count}",
def query_latest_snapshot_step_local(self):
max_step, max_step_path = 0, None
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
for fn in files:
fn = fn.strip("/")
if fn.endswith(".step"):
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
# We assume that both internlm ckpt and snapshot ckpt will store the '.step' file
# as an integrity flag.
step = int(fn.rsplit(".", maxsplit=1)[0])
if max_step < step:
@ -466,100 +680,55 @@ now step_count is {train_state.step_count}",
return max_step_path, max_step
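A standalone sketch of the local scan above: walk the save folder, treat every '<step>.step' file as an integrity flag, and return the directory holding the largest step. The folder layout in the usage comment is illustrative.

import os

def latest_step_folder(save_ckpt_folder: str):
    # return (folder, step) of the newest '<step>.step' flag under save_ckpt_folder
    max_step, max_step_path = 0, None
    for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
        for fn in files:
            fn = fn.strip("/")
            if fn.endswith(".step"):
                step = int(fn.rsplit(".", maxsplit=1)[0])
                if step > max_step:
                    max_step, max_step_path = step, root
    return max_step_path, max_step

# print(latest_step_folder("llm_ckpts"))   # e.g. ('llm_ckpts/snapshot/1', 40)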
def query_lastest_ckpt(self):
latest_checkpoint = None
latest_ckpt, step = None, -1
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
if self.save_ckpt_folder.startswith("boto3"):
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
elif self.save_ckpt_folder.startswith("local"):
latest_checkpoint, step = self.query_latest_snapshot_step_local()
else:
latest_checkpoint, step = None, 0
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
if backend == "boto3":
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
latest_ckpt = ":".join(["boto3", latest_ckpt])
elif backend == "local":
latest_ckpt, step = self.query_latest_snapshot_step_local()
if latest_ckpt and not latest_ckpt.startswith("local:"):
latest_ckpt = ":".join(["local", latest_ckpt])
if latest_checkpoint is not None:
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
send_alert_message(
address=self.feishu_address,
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
)
else:
if gpc.is_rank_for_log():
send_alert_message(
address=self.feishu_address,
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
)
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")
return latest_checkpoint
return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm")
def try_load_model(self, current_time=""):
model_load_path = None
def try_resume_training(self, train_state: TrainState, current_time=""):
if self.load_ckpt_folder and self.load_model_only_folder:
raise ValueError(
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
)
if self.load_ckpt_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_ckpt_folder
elif self.load_model_only_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_model_only_folder
else:
if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None:
if gpc.is_rank_for_log():
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
else:
load_path = self.load_ckpt_info["path"]
load_content = self.load_ckpt_info["content"]
load_type = self.load_ckpt_info["ckpt_type"]
# Loading model weights must be done before zero is initialized.
if model_load_path is not None:
load_model_checkpoint(folder=model_load_path, model=self.model)
load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
load_content_str = load_func(self, self.load_ckpt_info, train_state)
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
"""Attempt to restore the training state of the last ckpt.
# If we only load model weights, we need to rewrite the zero optimizer's fp32 buffer.
if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer):
reload_zero_fp32_buff(self.optimizer)
Args:
lr_scheduler (_LRScheduler): lr_scheduler object.
optimizer (Optimizer): optimizer object.
lr (float): learning rate.
train_state (dict): traing states.
train_dl (DataLoader): traning dataloader object
"""
if self.load_ckpt_folder is not None:
# load optimzier states.
if self.load_optimizer:
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
# load lr scheduler states.
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
# load training states.
load_context(self.load_ckpt_folder, train_dl, train_state)
# load dataloader sampler states.
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
if hasattr(train_state, "data_state_dict"):
train_dl.dataset.load_state_dict(
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_info : {self.load_ckpt_info}")
logger.info(
f"===========Resume training from `{load_path}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if load_content_str:
logger.info(f"===========Load contents are: {load_content_str}")
@llm_timeout(func_name="save_checkpoint")
def save_checkpoint(
self,
folder,
@ -600,8 +769,10 @@ set load_ckpt_folder or use default value \
)
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if scheduler:
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
@ -631,3 +802,12 @@ set load_ckpt_folder or use default value \
def set_save_folder(self, folder, step):
self.storage_manager.latest_save_folder = folder
self.storage_manager.latest_save_step = step
def try_ping_storage(self):
if gpc.get_global_rank() % 8 == 0:
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
self.storage_manager.save(test_fn, buff)
self.storage_manager.wait()
self.storage_manager.load(test_fn)
del buff

View File

@ -46,12 +46,12 @@ def get_fns(fp: str):
return storage_manager.get_fns(fp)
def llm_load(fp: str, *args, **kwargs):
return storage_manager.load(fp, *args, **kwargs)
def llm_load(fp: str, **kwargs):
return storage_manager.load(fp, **kwargs)
def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
def llm_save(save_path: str, saved_obj: Any, **kwargs):
storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
class StorageClient:
@ -63,19 +63,23 @@ class StorageClient:
self.handler = handler
@staticmethod
def load(client, load_path: str, *args, **kwargs):
def load(*args, **kwargs):
raise NotImplementedError
@staticmethod
def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
def sync_upload_fileobj(*args, **kwargs):
raise NotImplementedError
@staticmethod
def assert_fp_exists(client):
def async_upload_fileobj(*args, **kwargs):
raise NotImplementedError
@staticmethod
def get_fns(client):
def assert_fp_exists(*args, **kwargs):
raise NotImplementedError
@staticmethod
def get_fns(*args, **kwargs):
raise NotImplementedError
@ -92,40 +96,65 @@ class Boto3MetaInfo:
async_upload_fn: callable,
local_nvme_path=None,
) -> None:
self.is_async = is_async
# info needed by all operations.
self.client = handler
self.bucket_name = bucket_name
self.endpoint = endpoint
self.file_path = file_path
self.async_upload_fn = async_upload_fn
# info needed only when saving.
self.local_nvme_path = local_nvme_path
self.is_async = is_async
self.endpoint = endpoint
self.async_upload_fn = async_upload_fn
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"
@staticmethod
def unpack_boto3_save_meta(meta):
if meta.is_async:
return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
else:
return meta.client, meta.bucket_name, meta.file_path
@staticmethod
def unpack_boto3_nosave_meta(meta):
return meta.client, meta.bucket_name, meta.file_path
class LocalMetaInfo:
"""Local meta info for save/load etc."""
def __init__(self, handler: StorageClient, dest_path: str) -> None:
self.is_async = False
self.client = handler
self.dest_path = dest_path
def __init__(self, file_path: str) -> None:
self.file_path = file_path
self.async_upload_fn = None
self.is_async = False
@staticmethod
def unpack_local_save_meta(meta):
return (meta.file_path,)
@staticmethod
def unpack_local_nosave_meta(meta):
return (meta.file_path,)
def unpack_meta(meta):
args = []
is_async = meta.is_async
for k, v in meta.__dict__.items():
if k in ("endpoint", "async_upload_fn", "is_async"):
continue
if not is_async and k in ("local_nvme_path",):
continue
args.append(v)
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_save_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_save_meta(meta)
else:
raise ValueError(f"unkonwn meta info: {type(meta)}")
return args
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_nosave_meta(meta)
else:
raise ValueError(f"unkonwn meta info: {type(meta)}")
def compute_file_md5_by_chunk(file_name: str):
@ -136,6 +165,22 @@ def compute_file_md5_by_chunk(file_name: str):
return hash_md5.hexdigest()
def try_get_storage_backend(path: str):
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
if path.startswith("s3:"):
backend = "boto3"
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
else:
backend = "local"
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
return backend, sre
else:
return sre[0], sre[1] # (backend_prefix, split_path)
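A quick usage sketch, assuming try_get_storage_backend is importable from internlm.utils.storage_manager as the imports earlier in this commit indicate; the paths are illustrative. Prefixed paths are split on the first colon, unprefixed ones fall into the guess branch and log a warning.

from internlm.utils.storage_manager import try_get_storage_backend

print(try_get_storage_backend("local:llm_ckpts/"))
# ('local', 'llm_ckpts/')
print(try_get_storage_backend("boto3:s3://bucket.endpoint/user/job"))
# ('boto3', 's3://bucket.endpoint/user/job')
# a bare path such as "/mnt/ckpts/run1" has no prefix, so the backend is guessed and a warning is logged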
class Boto3Client(StorageClient):
"""
Boto3Client
@ -189,13 +234,11 @@ class Boto3Client(StorageClient):
)
@staticmethod
def sync_upload_fileobj(
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
): # pylint: disable=W0613
def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
torch.save(saved_obj, f, *args, **kwargs)
torch.save(saved_obj, f, **kwargs)
f.seek(0)
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
except handler.botocore.exceptions.EndpointConnectionError as exc:
@ -204,14 +247,7 @@ class Boto3Client(StorageClient):
) from exc
@staticmethod
def load(
handler,
bucket_name: str,
fp: str,
local_nvme_path: str, # pylint: disable=W0613
*args,
**kwargs,
) -> Dict:
def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
"""
Args:
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
@ -220,7 +256,7 @@ class Boto3Client(StorageClient):
with io.BytesIO() as f:
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
f.seek(0)
states = torch.load(f, *args, **kwargs)
states = torch.load(f, **kwargs)
except handler.botocore.exceptions.EndpointConnectionError as exc:
raise RuntimeError(
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@ -228,24 +264,37 @@ class Boto3Client(StorageClient):
return states
@staticmethod
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
if "Contents" in re:
return len(list(re["Contents"])) > 0
else:
return False
@staticmethod
def get_fns(handler, bucket_name: str, fp: str):
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
"""
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
if Boto3Client.is_fp_exists(handler, bucket_name, fp):
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
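A hedged sketch of the pagination pattern the method above relies on, using the standard boto3 list_objects_v2 paginator; the endpoint, bucket and prefix are placeholders that must be replaced, and credentials are assumed to be configured in the environment.

import boto3

client = boto3.client("s3", endpoint_url="http://<endpoint>:80")   # placeholder endpoint

def list_top_level_names(bucket: str, prefix: str):
    # collect the first path component under `prefix`, across all result pages
    paginator = client.get_paginator("list_objects_v2")
    names = set()
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            names.add(key.split(prefix, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
    return sorted(names)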
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
@ -273,37 +322,35 @@ class LocalClient(StorageClient):
super().__init__(None)
@staticmethod
def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
assert isinstance(handler, LocalClient)
def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
assert saved_obj is not None
fp_dirname = os.path.dirname(fp)
if not os.path.exists(fp_dirname):
os.makedirs(fp_dirname, exist_ok=True)
torch.save(saved_obj, fp, *args, **kwargs)
torch.save(saved_obj, fp, **kwargs)
@staticmethod
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
assert isinstance(handler, LocalClient)
assert os.path.exists(fp), f"{fp} is not found!"
with open(fp, "rb") as f:
states = torch.load(f, *args, **kwargs)
def load(load_path: str, **kwargs):
assert os.path.exists(load_path), f"{load_path} is not found!"
with open(load_path, "rb") as f:
states = torch.load(f, **kwargs)
return states
@staticmethod
def assert_fp_exists(handler, folder):
assert isinstance(handler, LocalClient)
def assert_fp_exists(folder):
assert os.path.exists(folder), folder
@staticmethod
def get_fns(handler, folder):
assert isinstance(handler, LocalClient)
assert os.path.exists(folder), f"folder '{folder}' not exists!"
fns = os.listdir(folder)
return fns
def get_fns(folder):
if not os.path.exists(folder):
if gpc.is_rank_for_log():
logger.warning(f"'{folder}' not found!")
return None
else:
return os.listdir(folder)
@staticmethod
def delete_obj(handler, fp: str):
assert isinstance(handler, LocalClient)
def delete_obj(fp: str):
if not os.path.isdir(fp):
os.remove(fp)
@ -327,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
assert match is not None, f"url '{fp}' is not a valid boto3 url"
bucket_name, endpoint = match.group(1), match.group(2)
endpoint = "http://" + endpoint + ":80"
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
if is_async:
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
else:
tmp_step_file = None
return Boto3MetaInfo(
is_async=is_async,
handler=None,
@ -341,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(None, fp)
return LocalMetaInfo(fp)
def get_mount_point_free_size(path: str):
@ -427,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
logger.error(f'tmp_local_folder only has "{free_size}" GB free space, less than 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
"""
example:
local:/path/to/checkpoint
@ -436,17 +486,14 @@ class StorageManager(metaclass=SingletonMeta):
Args:
path (str): _description_
"""
try:
backend, path = path.split(":", maxsplit=1)
except Exception as exc:
raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc
backend, path = try_get_storage_backend(path)
init_args = (None,)
if backend == "local":
meta_info = get_local_meta(path)
backend_key = backend
elif backend == "boto3":
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (meta_info.endpoint,)
if (
@ -474,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
def assert_fp_exists(self, folder) -> None:
meta = self._get_client(path=folder)
meta.client.assert_fp_exists(*unpack_meta(meta))
meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
def get_fns(self, folder) -> List[str]:
meta = self._get_client(path=folder)
return meta.client.get_fns(*unpack_meta(meta))
return meta.client.get_fns(*unpack_nosave_meta(meta))
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
meta = self._get_client(path=save_path)
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
if async_upload is None:
async_upload = self.async_mode
if not save_path.startswith("boto3:"):
async_upload = False
meta = self._get_client(save_path, async_upload)
if async_upload:
assert (
self.tmp_local_folder
@ -492,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
tmp_step_file = meta.local_nvme_path
self._to_be_del_files.append(tmp_step_file)
with open(tmp_step_file, "wb") as f:
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
self.async_task_peeding = True
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
self.upload_count += 1
def load(self, load_path: str, *args, **kwargs) -> Any:
def load(self, load_path: str, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
def delete_obj(self, fp: str):
meta = self._get_client(path=fp)
meta.client.delete_obj(*unpack_meta(meta))
meta.client.delete_obj(*unpack_nosave_meta(meta))
def _del_tmp_folder(self):
for fp in self._to_be_del_files:
@ -594,23 +646,24 @@ class StorageManager(metaclass=SingletonMeta):
if gpc.is_rank_for_log():
self.upload_count += 1
if self.async_mode:
if self.async_mode and self.latest_save_folder:
self.save(
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
saved_obj=dict({"step": self.latest_save_step}),
to_save_obj=dict({"step": self.latest_save_step}),
async_upload=False,
)
self.latest_save_folder = None
storage_manager: StorageManager = None
def init_storage_manager(ckpt_config):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
global storage_manager
storage_manager = StorageManager(
ckpt_config.enable_save_ckpt,
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
async_mode=ckpt_config.async_upload,
enable_save_ckpt,
tmp_local_folder=async_upload_tmp_folder,
async_mode=async_upload,
)

View File

@ -1,4 +1,13 @@
import datetime
import os
import signal
import socket
import traceback
from functools import wraps
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
class Timeout:
@ -24,3 +33,81 @@ class Timeout:
def __exit__(self, error_type, value, traceback):
signal.alarm(0)
ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)
timeout_threshold_dict = {
"initialize_distributed_env": 120,
"nopp_forward_backward_step": 360,
"initialize_model": 10,
"initialize_optimizer": 20,
"optim_step": 30,
"get_train_data_loader": 600,
"get_validation_data_loader": 60,
"load_new_batch": 10,
"record_current_batch_training_metrics": 10,
"save_checkpoint": 1200,
"interleaved_forward_backward_step": 600,
"nointerleaved_forward_backward_step": 600,
}
if ENABLE_TIMEOUT is not None:
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
else:
timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
def try_get_gpc_rank():
try:
from internlm.core.context import global_context as gpc
rank = gpc.get_global_rank()
except: # noqa # pylint: disable=bare-except
rank = "unknown"
return f"host-{socket.gethostname()}-rank-{rank}"
def llm_timeout(seconds=0, func_name=None):
"""timeout decorator, Note that this decorator cannot be reentrant,
otherwise the signal will be reset.
Args:
seconds (int, optional): timeout threshold. Defaults to 300.
func_name (str, optional): the func who is been waited to timeout.
"""
def decorator(func):
nonlocal func_name
if func_name is None:
func_name = func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
def _handle_timeout(signum, frame):
raise TimeoutError
nonlocal seconds
seconds = timeout_threshold_dict.get(func_name, seconds)
if seconds > 0:
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}")
raise e
finally:
signal.alarm(0)
return result
return wrapper
return decorator
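A self-contained sketch of the SIGALRM pattern llm_timeout is built on: the alarm fires after the threshold, the handler raises, and the alarm is always cleared in the finally block. Alarms only work in the main thread of a POSIX process; the decorated function here is invented for the demo.

import signal
import time
from functools import wraps

def timeout(seconds: int):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle_timeout(signum, frame):
                raise TimeoutError(f"{func.__name__} exceeded {seconds}s")
            if seconds > 0:
                signal.signal(signal.SIGALRM, _handle_timeout)
                signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)   # always clear any pending alarm
        return wrapper
    return decorator

@timeout(seconds=1)
def slow_step():
    time.sleep(5)

try:
    slow_step()
except TimeoutError as e:
    print(e)   # slow_step exceeded 1s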

tests/__init__.py Normal file
View File

View File

@ -0,0 +1,181 @@
import os
import shutil
from subprocess import PIPE, STDOUT, Popen
import pytest
import torch
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
OSS_IP = os.environ["OSS_IP"]
USER = os.environ["USER"]
JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt"
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ASYNC_TMP_FOLDER = "./async_tmp_folder"
# 1B
init_config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
adam=dict(
lr=1e-4,
),
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
checkpoint=False,
num_attention_heads=2,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
resume_tb_folder="",
tensorboard_folder="",
)
)
def init_naive_model():
# import so that MODEL_INITIALIZER gets populated
import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.registry import MODEL_INITIALIZER
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model))
model = NaiveAMPModel(
model=model,
output_to_fp32=False,
dtype=torch.bfloat16,
sync_buffer=False,
)
return model
def init_naive_optim(model):
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": 0.01}],
lr=1e-4,
betas=(0.9, 0.95),
eps=1e-8,
)
return naive_optimizer
def init_hybrid_optim(model):
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": 0.01}],
lr=1e-4,
betas=(0.9, 0.95),
eps=1e-8,
)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=Config(
dict(
fp16=dict(
initial_scale=2**16,
min_scale=1,
growth_interval=1000,
),
growth_factor=2,
backoff_factor=0.5,
max_scale=2**24,
hysteresis=2,
)
),
zero_cfg=Config(
dict(
overlap_sync_grad=False,
overlap_sync_param=False,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
)
),
param_bcast_sync_handler=None,
)
return optimizer
@pytest.fixture(autouse=True, scope="function")
def reset_singletons():
SingletonMeta._instances = {}
def reset_seed():
from internlm.core.context.random import _SEED_MANAGER
_SEED_MANAGER.reset()
@pytest.fixture(scope="module")
def init_dist_and_model(rank=0, world_size=1):
from internlm.initialize import initialize_distributed_env
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
# setup
print("set up", flush=True)
model = init_naive_model()
# opim = init_naive_optim(model)
opim = init_hybrid_optim(model)
yield model, opim
# teardown
del model, opim
print("teardown", flush=True)
gpc.destroy()
reset_seed()
def enter_flag(text):
print(f"{text} begin!", flush=True)
yield
print(f"{text} end!", flush=True)
def del_tmp_file():
try:
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
except FileNotFoundError:
pass
try:
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
except FileNotFoundError:
pass
try:
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
results, presults = "", ""
for line in iter(output.stdout.readline, b""):
results += str(line.rstrip())
presults += line.rstrip().decode() + "\n"
print(presults, flush=True)
except FileNotFoundError:
pass

View File

@ -0,0 +1,247 @@
import os
import pytest
import torch
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.storage_manager import wait_async_upload_finish
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
BOTO_SAVE_PATH,
LOCAL_SAVE_PATH,
del_tmp_file,
init_dist_and_model,
reset_singletons,
)
TOTAL_STEP = 6
CKPT_EVERY = 4
SNPASHOT_EVERY = 2
ckpt_config_list = [
# Old interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=BOTO_SAVE_PATH,
load_optimizer=True,
checkpoint_every=CKPT_EVERY,
async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_model_only_folder=None,
load_given_ckpt=False,
load_ckpt_folder=None,
is_old_api=True,
),
# Old interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=LOCAL_SAVE_PATH,
load_optimizer=True,
checkpoint_every=CKPT_EVERY,
async_upload=False,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_model_only_folder=None,
load_given_ckpt=False,
load_ckpt_folder=None,
is_old_api=True,
),
# New interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=BOTO_SAVE_PATH,
checkpoint_every=CKPT_EVERY,
async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
is_old_api=False,
auto_resume=True,
),
dict(
enable_save_ckpt=True,
save_ckpt_folder=LOCAL_SAVE_PATH,
checkpoint_every=CKPT_EVERY,
async_upload=False,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_ckpt_folder=None,
is_old_api=False,
auto_resume=True,
),
]
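# The helpers below fill model/optimizer state with a known constant and later check that the values
# survive a save/load round trip, covering both HybridZeroOptimizer and plain torch optimizers.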
def overwrite_optim_state(optim, set_value):
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
p.data.fill_(set_value)
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
)
fp16_p.fill_(set_value)
else:
for group in optim.param_groups:
for p in group["params"]:
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
p.data.fill_(set_value)
def compare_optim_state(optim1, optim2):
re = True
if isinstance(optim1, HybridZeroOptimizer):
fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
re &= group_id_1 == group_id_2
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
else:
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
for p1, p2 in zip(group1["params"], group2["params"]):
re &= torch.equal(p1, p2)
return re
def compare_optim_value(optim, value):
re = True
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
)
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
else:
for group in optim.param_groups:
for p in group["params"]:
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
return re
def overwrite_model_value(model, value):
for p in model.parameters():
# p.copy_(torch.full_like(p, value, dtype=p.dtype))
p.data.fill_(value)
def compare_model_value(model, value):
re = True
for p in model.parameters():
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
return re
@pytest.fixture(scope="function")
def del_tmp():
del_tmp_file()
yield
del_tmp_file()
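# End-to-end CheckpointManager test: run TOTAL_STEP steps while writing snapshot and regular
# checkpoints, verify that auto-resume picks up the latest snapshot, then point load_ckpt_info at
# the step-4 checkpoint and verify that model/optimizer values and the train state are restored.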
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
ckpt_config = Config(ckpt_config)
assert ckpt_config.checkpoint_every < TOTAL_STEP
assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
model, opim = init_dist_and_model
train_state = TrainState(gpc.config, None)
if isinstance(opim, HybridZeroOptimizer):
print("Is HybridZeroOptimizer!", flush=True)
else:
print("Is naive Adam!", flush=True)
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
latest_ckpt_step = None
for i in range(TOTAL_STEP + 1):
overwrite_model_value(model, i)
overwrite_optim_state(opim, i)
train_state.batch_count = i
train_state.step_count += 1
save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state)
if save_ckpts:
latest_ckpt_step = i
ckpt_mm.try_save_checkpoint(train_state)
wait_async_upload_finish()
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
assert latest_ckpt_info is not None
latest_ckpt = latest_ckpt_info["path"]
if ckpt_mm.save_ckpt_folder.startswith("local"):
assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
else:
assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
del ckpt_mm
SingletonMeta._instances = {}
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
ckpt_mm.try_resume_training(train_state)
assert latest_ckpt_step == 5
assert train_state.step_count == 6
assert train_state.batch_count == 6
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
if ckpt_mm.save_ckpt_folder.startswith("local:"):
ckpt_mm.load_ckpt_info = dict(
path=os.path.join(LOCAL_SAVE_PATH, "4"),
content=CheckpointLoadMask(("all",)),
ckpt_type=CheckpointLoadType.INTERNLM,
)
else:
ckpt_mm.load_ckpt_info = dict(
path=os.path.join(BOTO_SAVE_PATH, "4"),
content=CheckpointLoadMask(("all",)),
ckpt_type=CheckpointLoadType.INTERNLM,
)
ckpt_mm.try_resume_training(train_state)
assert train_state.step_count == 4
assert train_state.batch_count == 4
assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
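# Smoke test: CheckpointManager should be able to ping its storage backend for every config.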
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
ckpt_config = Config(ckpt_config)
model, opim = init_dist_and_model
SingletonMeta._instances = {}
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
ckpt_mm.try_ping_storage()
if __name__ == "__main__":
pytest.main()

View File

@ -0,0 +1,89 @@
import os
import pytest
import torch
from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
BOTO_SAVE_PATH,
LOCAL_SAVE_PATH,
del_tmp_file,
init_dist_and_model,
reset_singletons,
)
ASYNC_TMP_FOLDER = "./async_tmp_folder"
ckpt_config_list = [
# async boto
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=BOTO_SAVE_PATH,
test_id=0,
),
# sync local
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=LOCAL_SAVE_PATH,
test_id=1,
),
# sync boto
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=BOTO_SAVE_PATH,
test_id=2,
),
# async local
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=LOCAL_SAVE_PATH,
test_id=3,
),
]
@pytest.fixture(scope="function")
def del_tmp():
del_tmp_file()
yield
del_tmp_file()
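# Round-trip test for the storage manager: save a random tensor with llm_save under each
# backend/upload mode, wait for async uploads where needed, and check llm_load returns identical data.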
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
from internlm.utils.storage_manager import (
check_folder,
get_fns,
init_storage_manager,
llm_load,
llm_save,
wait_async_upload_finish,
)
ckpt_config = Config(ckpt_config)
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
async_upload = get_config_value(ckpt_config, "async_upload", False)
init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
tobj = torch.rand(64, 64)
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
llm_save(save_fn, tobj)
if ckpt_config.test_id == 0:
wait_async_upload_finish()
check_folder(save_fn)
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
load_obj = llm_load(save_fn, map_location="cpu")
assert 0 == ((load_obj != tobj).sum())

View File

@ -0,0 +1,119 @@
import fcntl
import os
import time
from multiprocessing import Process
import pytest
import torch
import torch.distributed as dist
os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1" # noqa # pylint: disable=wrong-import-position
os.environ["NCCL_TIMEOUT"] = "5"
from internlm.utils.timeout import llm_timeout
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
init_config,
)
WORLD_SIZE = 2
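# Each helper below is wrapped with llm_timeout and is expected to raise TimeoutError instead of
# hanging: a plain oversleep, an unbalanced NCCL all_reduce, and a contended file lock.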
@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
time.sleep(10)
@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
# see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
    # 'NCCL_ASYNC_ERROR_HANDLING' does not take effect for the first collective call, hence the warm-up all_reduce below.
buff = torch.ones([64, 64]).cuda(rank)
dist.all_reduce(buff) # lazy communicator init
torch.cuda.synchronize()
if rank == 0:
dist.all_reduce(buff)
        torch.cuda.synchronize()  # the main thread will hang here (rank 1 never joins the all_reduce).
else:
time.sleep(9999)
@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
if rank == 1:
time.sleep(5)
with open(stop_file_path, "r", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 blocks here: rank 0 already holds the lock.
if rank == 0:
            time.sleep(99999)  # rank 0 holds the lock and hangs.
f.seek(0)
f.read()
fcntl.flock(f, fcntl.LOCK_UN)
def local_timeout(rank, _):
try:
fake_timeout_func()
except TimeoutError as e:
print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
else:
assert False, "It should timeout!"
def gpc_timeout(rank, world_size):
from internlm.initialize import initialize_distributed_env
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
try:
nccl_timeout_func(rank)
except TimeoutError as e:
print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
time.sleep(5) # wait rank 0 to be killed
else:
time.sleep(5) # give some time to let Watchdog kill rank 0.
assert False, "It should timeout!"
def file_lock_timeout(rank, _, stop_file_path):
if rank == 0:
with open(stop_file_path, "w"):
pass
try:
try_file_lock(rank, stop_file_path)
except TimeoutError as e:
print(e, flush=True)
else:
assert False, "It should timeout!"
finally:
if rank == 0:
os.remove(stop_file_path)
timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]
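# Driver: spawn one process per rank for each scenario and give it 15 seconds to finish; any
# process still alive afterwards is terminated. The workers themselves assert that a TimeoutError
# was raised.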
@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
timeout_func, world_size, other_args = timeout_func_and_args
procs = []
for i in range(world_size):
if other_args is None:
args = (i, world_size)
else:
args = (i, world_size, other_args)
proc = Process(target=timeout_func, args=args)
proc.start()
procs.append(proc)
for proc in procs:
proc.join(15)
if proc.is_alive():
proc.terminate()
proc.join()

View File

@ -35,7 +35,7 @@ from internlm.utils.common import (
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@ -73,7 +73,6 @@ def main(args):
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
get_tflops_func = partial(
get_megatron_flops,
@ -96,21 +95,11 @@ def main(args):
    # initialize custom llm logger
uniscale_logger = initialize_llm_logger(start_time=current_time)
# initialize and resume train state
train_state = TrainState(gpc.config)
# initialize model
model = initialize_model()
with open(args.config, "r") as f:
config_lines = f.readlines()
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
)
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@ -118,15 +107,25 @@ def main(args):
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=4)
val_dls = get_validation_data_loader()
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized.
ckpt_manager.try_load_model(current_time)
# initialize and resume train state
train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_dl=train_dl,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
)
# Loading other persistent training states.
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
ckpt_manager.try_resume_training(train_state, current_time)
    # initialize custom llm writer
writer = Writer(
@ -195,11 +194,7 @@ def main(args):
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
bench_gpu()
bench_net()
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
start_time = time.time()
timer("one-batch").start()
@ -241,7 +236,7 @@ def main(args):
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address,
address=gpc.config.monitor.alert.feishu_alert_address,
message=f"Warning: skip parameter update at step {batch_count}.",
)
@ -301,11 +296,15 @@ if __name__ == "__main__":
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
with initialize_monitor_manager(
job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
):
try:
main(args)
except Exception:
logger.error(
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
)
mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
mm.monitor_exception(
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
)

View File

@ -1 +1 @@
0.1.0
0.2.0