diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 1f1993f..cbcd0e5 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -1,4 +1,5 @@ JOB_NAME = "7b_train" +DO_ALERT = False SEQ_LEN = 2048 HIDDEN_SIZE = 4096 @@ -22,13 +23,16 @@ CHECKPOINT_EVERY = 50 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states). - # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights. - load_optimizer=True, # Wheter to load optimizer states when continuing training. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) @@ -52,6 +56,8 @@ data = dict( min_length=50, # train_folder=TRAIN_FOLDER, # valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=10, + diag_outlier_ratio=1.1, ) grad_scaler = dict( @@ -147,3 +153,12 @@ parallel = dict( cudnn_deterministic = False cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + ), +) diff --git a/doc/code-docs/locales/en/LC_MESSAGES/initialize.po b/doc/code-docs/locales/en/LC_MESSAGES/initialize.po index 14955c0..c3ea055 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/initialize.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/initialize.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: InternLM \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-09-07 14:15+0800\n" +"POT-Creation-Date: 2023-09-08 15:32+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: zh_CN\n" @@ -19,26 +19,29 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.12.1\n" -#: ../../source/initialize.rst:2 b829330eebd24620b745072bbfc26c98 +#: ../../source/initialize.rst:2 msgid "训练构建" msgstr "Training Setup" -#: ../../source/initialize.rst:7 8c8472b4647a4de8998d75b9ec6f09eb +#: ../../source/initialize.rst:7 msgid "命令行参数解析" msgstr "Argument Parsing" -#: ../../source/initialize.rst:8 f74176fa4aee4bbfaf989ffab9283ee7 +#: ../../source/initialize.rst:9 +#, fuzzy msgid "" "InternLM 使用 `argparse `_" -" 库来向InternLM运行时提供命令行参数配置。用户可 使用 " +" 库来向InternLM运行时提供命令行参数配置。用户可使用 " "``internlm.initialize.get_default_parser()`` 来获取 InternLM " "的默认解析器,其中包含一些内置参数,用户可以向此解析器添加自定义参数。" msgstr "" -"InternLM uses the `argparse `_ library to supply commandline " -"configuration to the InternLM runtime. Use ``internlm.initialize.get_default_parser()`` to get InternLM's default " -"parser with some builtin arguments, users can add custom parameters to this parser." +"InternLM uses the `argparse " +"`_ library to supply " +"commandline configuration to the InternLM runtime. Use " +"``internlm.initialize.get_default_parser()`` to get InternLM's default " +"parser with some builtin arguments, users can add custom parameters to " +"this parser." -#: 9930855b85bf41ed8712fc40e1e034f7 #: internlm.initialize.launch.get_default_parser:1 of msgid "" "Reads user command line and uses an argument parser to parse the input " @@ -46,9 +49,6 @@ msgid "" " local rank, backend for torch.distributed." msgstr "" -#: 015003b013e346bea15b4514f2001a25 544472c2ce3c43bfb59317083c6b55c9 -#: 7ee60ba1a92a4b9e8174049fb498a4f0 bca7c66f1a5a4517958bcea1e09d5d10 -#: f5cbe452ae694c7884ac4596a7735bf6 #: internlm.initialize.initialize_trainer.initialize_trainer #: internlm.initialize.launch.get_default_parser #: internlm.train.training_internlm.get_train_data_loader @@ -57,55 +57,50 @@ msgstr "" msgid "返回" msgstr "" -#: 9b04c3d6b98b44ee89f800b71e8d80a9 #: internlm.initialize.launch.get_default_parser:4 of msgid "" "Returns the parser with the default arguments, the user may add " "customized arguments into this parser." msgstr "" -#: 147005b197e64c4b9a96a7cfe78045bc 3634f79c9aa547a48eb3fd7f150deb51 -#: d3f0aa4143c84b719cd0b53170dd86c1 #: internlm.initialize.initialize_trainer.initialize_trainer #: internlm.initialize.launch.get_default_parser #: internlm.train.training_internlm.initialize_model of msgid "返回类型" msgstr "" -#: ../../source/initialize.rst:25 db2bf9d3ff81483dbf218e63dd4bbbe4 +#: ../../source/initialize.rst:25 msgid "模型初始化" msgstr "Model Initialization" -#: 5c2e33e254d4495fbc4b0226aac1fddb #: internlm.train.training_internlm.initialize_model:1 of msgid "Initialize model with Automatic Mixed Precision." msgstr "" -#: c1254615508542b680daf73374844f9e #: internlm.train.training_internlm.initialize_model:3 of msgid "The neural network model to be trained or evaluated." msgstr "" -#: ../../source/initialize.rst:29 b9867771b9da40cd8f3a55ee5ab95f65 +#: ../../source/initialize.rst:29 msgid "InternLM 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下:" msgstr "" "InternLM uses the field ``model_type`` and ``model`` in the config file " "to control model initialization process. An example model initialization " "configuratio" -#: ../../source/initialize.rst:57 984a38d7f63949ecbb0d8b2ef3459d57 +#: ../../source/initialize.rst:57 msgid "字段 ``model_type`` 指明了要初始化的模型类型" msgstr "" "The field ``model_type`` specifics the model type has been registered and" " to be initialized." -#: ../../source/initialize.rst:58 9f04ad0f145f4e40bc75a3ef45c7a59d +#: ../../source/initialize.rst:58 msgid "字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置" msgstr "" "The parameters in field ``model`` specific the configuration settings " "during model initialization." -#: ../../source/initialize.rst:60 d7780e355bb6429bb5151d9a0e6d7e36 +#: ../../source/initialize.rst:60 msgid "" "值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` " "注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 " @@ -117,109 +112,90 @@ msgstr "" " instantiated object of class ``internlm.util.registry.Registry``, the " "example is shown as follows." -#: ../../source/initialize.rst:72 d863f71b208a49a09d2d00537e331962 +#: ../../source/initialize.rst:72 msgid "优化器初始化" msgstr "Optimizer Initialization" -#: acaafdc9bb96434bbd42a98f74187db1 #: internlm.train.training_internlm.initialize_optimizer:1 of msgid "Initialize optimizer." msgstr "" -#: 62fc4215c9a44bda8b31c933db90f270 93c398e44f6a4f708ba064250a3d253c -#: e2bebdd751724915a65dec444bb89e25 #: internlm.initialize.initialize_trainer.initialize_trainer #: internlm.train.training_internlm.get_train_data_loader #: internlm.train.training_internlm.initialize_optimizer of msgid "参数" msgstr "" -#: 2033ee96ded8423a80268b337ba9549c #: internlm.train.training_internlm.initialize_optimizer:3 of msgid "Your model instance to be trained or evaluated." msgstr "" -#: df01b44c724b4326a6c85b44694262ba #: internlm.train.training_internlm.initialize_optimizer:6 of msgid "A tuple of (optimizer, beta2_scheduler, lr_scheduler)." msgstr "" -#: ../../source/initialize.rst:79 0b46b890048f4758a9d56e0540759d9f +#: ../../source/initialize.rst:79 msgid "数据加载器初始化" msgstr "Dataloader Initialization" -#: 58e39b26ab4849788e792df386f01d7e #: internlm.train.training_internlm.get_train_data_loader:1 of msgid "Generate and return the training data loader." msgstr "" -#: 37a91c167e0b4e5fad4edcc3caf0d012 #: internlm.train.training_internlm.get_train_data_loader:3 of msgid "number of subprocesses used for dataloader." msgstr "" -#: 947aba2a4f86420d9b2660425a6043cc #: internlm.train.training_internlm.get_train_data_loader:5 of msgid "generate function for dataset." msgstr "" -#: 8a8f5ee665cb4e15bc33194c0b1f346c #: internlm.train.training_internlm.get_train_data_loader:7 of msgid "dataset sampler for training dataloader." msgstr "" -#: 4c3e1e896e7940bf97c124909d2e7f36 #: internlm.train.training_internlm.get_train_data_loader:9 of msgid "collate function for training dataloader." msgstr "" -#: d9f0740d048c48888e82c8f8a78e33cd #: internlm.train.training_internlm.get_train_data_loader:12 of msgid "A tuple of (train_dl, dataset_types)." msgstr "" -#: ../../source/initialize.rst:86 1c4df708ff5c47f6abae32617bf2ed31 +#: ../../source/initialize.rst:86 msgid "Trainer 初始化" msgstr "Trainer Initialization" -#: d535583dbcb245499e19c09f3f8b534a #: internlm.initialize.initialize_trainer.initialize_trainer:1 of msgid "" "Core function to wrap the essential training components with our " "functionality based on the config which is loaded into gpc.config." msgstr "" -#: 3e370234e4b245e4b9cae1fe235df8ff #: internlm.initialize.initialize_trainer.initialize_trainer:4 of msgid "Your model instance or a function to build the model." msgstr "" -#: b716a4a264234011a7b51fa12e575651 #: internlm.initialize.initialize_trainer.initialize_trainer:6 of msgid "Your optimizer for training." msgstr "" -#: 6a54ce9d516f4f14bab281c9db9816e8 #: internlm.initialize.initialize_trainer.initialize_trainer:8 of msgid "Your criterion instance." msgstr "" -#: ff9dfd04d31b4dc6afbdd841829b4c33 #: internlm.initialize.initialize_trainer.initialize_trainer:10 of msgid "Dataloader for training." msgstr "" -#: de345f9a457a4a88bf60b4ee96535e31 #: internlm.initialize.initialize_trainer.initialize_trainer:12 of msgid "Dataloader for testing." msgstr "" -#: 64e646b25420424d9dcdfb1ad7de5e6f #: internlm.initialize.initialize_trainer.initialize_trainer:14 of msgid "Your lr scheduler instance, optional." msgstr "" -#: 39c7132bfafe4e22ae373081fee711ce #: internlm.initialize.initialize_trainer.initialize_trainer:17 of msgid "" "A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``" diff --git a/doc/code-docs/locales/en/LC_MESSAGES/profiler.po b/doc/code-docs/locales/en/LC_MESSAGES/profiler.po index 37aa6bb..71adf14 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/profiler.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/profiler.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: InternLM \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-09-07 10:56+0800\n" +"POT-Creation-Date: 2023-09-08 15:32+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language: en\n" @@ -19,122 +19,147 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.12.1\n" -#: ../../source/profiler.rst:2 81b1b5f4414449dfaf107815a911f300 +#: ../../source/profiler.rst:2 msgid "性能分析" msgstr "Profiler" -#: ../../source/profiler.rst:7 d709646ebb314e9abb6a4839a21180bd +#: ../../source/profiler.rst:7 msgid "Torch Profiler" msgstr "" -#: ../../source/profiler.rst:9 4b5b73486c794c7a9168ad19999e12e1 +#: ../../source/profiler.rst:9 msgid "" "InternLM 使用 ``internlm.train.initialize_llm_profile()`` " "来收集和分析模型训练或推理期间的性能数据,如 CPU/CUDA/memory 等性能数据。这个实现基于 `torch.profiler " "`_ ,输出的性能分析 trace 文件可以使用 " "`tensorboard `_ 进行可视化。" msgstr "" -"InternLM uses ``internlm.train.initialize_llm_profile()`` to profile performance data, execution time duration and breakdown analysis of " -"step time. The implementation is based on `torch.profiler `_ and output tracing files can " -"be visualized with `tensorboard `_." +"InternLM uses ``internlm.train.initialize_llm_profile()`` to profile " +"performance data, execution time duration and breakdown analysis of step " +"time. The implementation is based on `torch.profiler " +"`_ and output tracing " +"files can be visualized with `tensorboard `_." -#: ../../source/profiler.rst:11 40ff4289735c43fdbeca871b65e82be4 +#: ../../source/profiler.rst:11 msgid "" "用户如果想使用这个 torch 性能分析工具,需要在启动训练时传递 ``--profiling`` 参数以启用性能分析。完成 torch " "性能分析后,用户可以在 ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` " "文件夹中看到性能分析结果。" msgstr "" -"To use this torch profiler tool, you need to enable profiling by passing the ``--profiling`` flag when starting training. After torch " -"profiling is completed, you can find the profiling results in the ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` folder." +"To use this torch profiler tool, you need to enable profiling by passing " +"the ``--profiling`` flag when starting training. After torch profiling is" +" completed, you can find the profiling results in the " +"``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` folder." + +#: ../../source/profiler.rst:13 +msgid "实际运行生成的 ``Torch Profiler`` 目录结构如下:" +msgstr "The directory structure of ``Torch Profiler`` generated files is as follows:" + +#: ../../source/profiler.rst:22 +msgid "其中, ``traces`` 可以通过 ``TensorBoard`` 可视化,运行命令" +msgstr "Among them, ``traces`` can be visualized through ``TensorBoard`` and run with the command" + +#: ../../source/profiler.rst:29 +msgid "" +"在打开的 ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` " +"页面可以看到Operator和GPU Kernel的性能分析时间线如下,更多的功能请参考 `torch profiler with " +"tensorboard " +"`_" +msgstr "In the opened ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` page, you can see the timeline of profiled operators and GPU kernels. For more usage, please refer to `torch profiler with tensorboard `_" -#: 876a2993b82645f7b56553fe64b514ec #: internlm.train.training_internlm.initialize_llm_profile:1 of msgid "Initialize and return the profiler context manager instance." msgstr "" -#: ../../source/profiler.rst:16 3ab9536155ea4f3b8adb318005970bb8 +#: ../../source/profiler.rst:38 msgid "Memory Profiler" msgstr "" -#: ../../source/profiler.rst:18 0ec4091fef5b47c58488618bfb4dcd3b +#: ../../source/profiler.rst:40 msgid "" "InternLM 提供了一个实用的内存分析工具 " "``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` 来监控实际的 GPU" " 内存使用情况。在实现中,会对模型数据(包括模型参数、模型梯度和优化器状态)和非模型数据(包括激活值)分别进行详细的统计。" msgstr "" -"InternLM provides a practical solution ``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` to monitor actual GPU memory usage. " -"In the implmentation, model data (including model parameters, model gradients, and optimizer states) and non-model data " -"(including activations) are calculated." +"InternLM provides a practical solution " +"``internlm.utils.simple_memory_profiler.SimpleMemoryProfiler`` to monitor" +" actual GPU memory usage. In the implmentation, model data (including " +"model parameters, model gradients, and optimizer states) and non-model " +"data (including activations) are calculated." -#: ../../source/profiler.rst:20 cd62bbd5b122480da21e10453b95090c +#: ../../source/profiler.rst:42 msgid "" "要使用这个内存分析工具,用户需要在启动训练时传递 ``--profiling`` 参数以启用内存分析。完成内存分析后,用户可以在 " "``memory_trace/rank{}_dp{}_tp{}`` 文件夹中找到特定 rank " "对应的内存分析结果(包括不同时间点的内存使用日志和显示总体内存使用情况的太阳图表)。" msgstr "" -"To use this memory profiler tool, you need to enable profiling by passing the ``--profiling`` flag when starting training. After memory " -"profiling is completed, you can find the profiling results (including logs of memory usage at different time point and sunburst charts " -"showing overall memory usage) for a specific rank device in the ``memory_trace/rank{}_dp{}_tp{}`` folder." +"To use this memory profiler tool, you need to enable profiling by passing" +" the ``--profiling`` flag when starting training. After memory profiling " +"is completed, you can find the profiling results (including logs of " +"memory usage at different time point and sunburst charts showing overall " +"memory usage) for a specific rank device in the " +"``memory_trace/rank{}_dp{}_tp{}`` folder." + +#: ../../source/profiler.rst:44 +msgid "实际运行生成的 ``memory_trace`` 目录结构如下:" +msgstr "The directory structure of ``memory_trace`` generated files is as follows:" + +#: ../../source/profiler.rst:107 +msgid "其中, ``memory.log`` 的内容示例如下:" +msgstr "An example of ``memory.log`` is as follows:" + +#: ../../source/profiler.rst:157 +msgid "模型参数的太阳图示例如下:" +msgstr "An example of model parameters sunburst chart is as follows:" -#: a858f1377b714cd5ab0cf749d8dbfeb7 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:1 of msgid "A memory profiler for a llm model." msgstr "" -#: 08d4cca2ba154080ba72e7d3fbd2a344 36e25696cf7b4a8ca5472e86fd5eea7e #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point of msgid "参数" msgstr "" -#: dea424767bc44ff689d582c67b07d637 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:3 of msgid "The model to profile." msgstr "" -#: 4f3892910fa14324810c3f33c6af4fdd #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:5 of msgid "The optimizer used for training the model." msgstr "" -#: a698f2f57eef4e47a22faa546c687979 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:7 of msgid "The file to write the memory state information to." msgstr "" -#: 448fc2b81c794d228ec4b413356289ea #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler:9 of msgid "number of steps to trace." msgstr "" -#: 85b3b9d4147547fd89c286f003395469 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:1 of msgid "Record the memory state." msgstr "" -#: d474a46415674d35a2c87c57ebff20ea #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:3 of msgid "The options to include in the memory state. Defaults to \"\"." msgstr "" -#: 16261fe5b1df4b13bd23f76d97caf1be #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:5 of msgid "Whether to create a new memory record file. Defaults to False." msgstr "" -#: 3b18845958204f07a6b80b6afb2221f5 d11f76d03d0d456889dee6d267dd4b74 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step of msgid "返回" msgstr "" -#: 0deeb9555efb4aa798fd9d146826e961 46b50da453f1475a88e096b5d6ed8afb #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.point:8 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step:3 of msgid "None" msgstr "" -#: 4f2331ac352d4057a852b013ca688ed3 #: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step:1 of msgid "Update the memory state of the optimizer state." msgstr "" diff --git a/doc/code-docs/source/profiler.rst b/doc/code-docs/source/profiler.rst index 4622d3c..7ff42cb 100644 --- a/doc/code-docs/source/profiler.rst +++ b/doc/code-docs/source/profiler.rst @@ -10,6 +10,28 @@ InternLM 使用 ``internlm.train.initialize_llm_profile()`` 来收集和分析 用户如果想使用这个 torch 性能分析工具,需要在启动训练时传递 ``--profiling`` 参数以启用性能分析。完成 torch 性能分析后,用户可以在 ``{JOB_NAME}/{start_time}/traces/rank{}_dp{}_tp{}_pp{}`` 文件夹中看到性能分析结果。 +实际运行生成的 ``Torch Profiler`` 目录结构如下: + +.. code-block:: bash + + # tree ./7b_train/Sep08_11-00-51/traces -L 2 + ./7b_train/Sep08_11-00-51/traces/ + └── rank0_dp0_tp0_pp0 + └── SH-IDC1-10-140-1-78_238619.1694142354680.pt.trace.json + +其中, ``traces`` 可以通过 ``TensorBoard`` 可视化,运行命令 + +.. code-block:: bash + + # visualize traces with tensorboard and custom port + tensorboard --logdir rank0_dp0_tp0_pp0 --port 10088 + +在打开的 ``TensorBoard -> PyTorch Profiler -> Views -> Trace`` 页面可以看到Operator和GPU Kernel的性能分析时间线如下,更多的功能请参考 `torch profiler with tensorboard `_ + +.. figure:: ../../imgs/torch_profiler_trace.png + :scale: 45% + :class: with-border + .. autofunction:: internlm.train.initialize_llm_profile Memory Profiler @@ -19,5 +41,124 @@ InternLM 提供了一个实用的内存分析工具 ``internlm.utils.simple_memo 要使用这个内存分析工具,用户需要在启动训练时传递 ``--profiling`` 参数以启用内存分析。完成内存分析后,用户可以在 ``memory_trace/rank{}_dp{}_tp{}`` 文件夹中找到特定 rank 对应的内存分析结果(包括不同时间点的内存使用日志和显示总体内存使用情况的太阳图表)。 +实际运行生成的 ``memory_trace`` 目录结构如下: + +.. code-block:: bash + + # tree ./memory_trace -L 2 + ./memory_trace + ├── rank0_dp0_tp0 # Profiling results for a specific rank device + │   ├── activation_memory_sunburst.html # Sunburst chart showing activation memory usage + │   ├── grads_memory_sunburst.html # Sunburst chart showing gradient memory usage + │   ├── memory.log # Log of GPU memory usage at different time points + │   ├── os_memory_sunburst.html # Sunburst chart showing optimizer state memory usage + │   ├── params_memory_sunburst.html # Sunburst chart showing parameter memory usage + │   └── summary_sunburst.html # Sunburst chart showing overall memory usage + ├── rank1_dp1_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + ├── rank2_dp2_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + ├── rank3_dp3_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + ├── rank4_dp4_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + ├── rank5_dp5_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + ├── rank6_dp6_tp0 + │   ├── activation_memory_sunburst.html + │   ├── grads_memory_sunburst.html + │   ├── memory.log + │   ├── os_memory_sunburst.html + │   ├── params_memory_sunburst.html + │   └── summary_sunburst.html + └── rank7_dp7_tp0 + ├── activation_memory_sunburst.html + ├── grads_memory_sunburst.html + ├── memory.log + ├── os_memory_sunburst.html + ├── params_memory_sunburst.html + └── summary_sunburst.html + +其中, ``memory.log`` 的内容示例如下: + +.. code-block:: bash + + Memory State: + time: 37.56313228607178 + ---summary--- + total_memory: 55953.56 MB + params_memory: 13965.51 MB, grads_memory: 13965.51 MB, os_params_memory: 3461.52 MB, os_state_memory: 6923.03 MB, activation_memory: 17638.00 MB + + Memory State: + time: 38.46969723701477 + ---summary--- + total_memory: 38315.56 MB + params_memory: 13965.51 MB, grads_memory: 13965.51 MB, os_params_memory: 3461.52 MB, os_state_memory: 6923.03 MB, activation_memory: 0.00 MB + ---Layout--- + params_layout: + layer: param_mem, layer_mem: 0.00 MB, total_mem: 13965.51 MB + layer: param_mem.embedding, layer_mem: 0.00 MB, total_mem: 806.00 MB + layer: param_mem.embedding.weight, layer_mem: 806.00 MB, total_mem: 806.00 MB + layer: param_mem.blocks, layer_mem: 0.00 MB, total_mem: 12353.50 MB + layer: param_mem.blocks.0, layer_mem: 0.00 MB, total_mem: 386.05 MB + layer: param_mem.blocks.0.mixer, layer_mem: 0.00 MB, total_mem: 128.03 MB + layer: param_mem.blocks.0.mixer.Wqkv, layer_mem: 0.00 MB, total_mem: 96.02 MB + layer: param_mem.blocks.0.mixer.Wqkv.weight, layer_mem: 96.00 MB, total_mem: 96.00 MB + layer: param_mem.blocks.0.mixer.Wqkv.bias, layer_mem: 0.02 MB, total_mem: 0.02 MB + layer: param_mem.blocks.0.mixer.out_proj, layer_mem: 0.00 MB, total_mem: 32.01 MB + layer: param_mem.blocks.0.mixer.out_proj.weight, layer_mem: 32.00 MB, total_mem: 32.00 MB + layer: param_mem.blocks.0.mixer.out_proj.bias, layer_mem: 0.01 MB, total_mem: 0.01 MB + layer: param_mem.blocks.0.norm1, layer_mem: 0.00 MB, total_mem: 0.01 MB + layer: param_mem.blocks.0.norm1.weight, layer_mem: 0.01 MB, total_mem: 0.01 MB + layer: param_mem.blocks.0.norm2, layer_mem: 0.00 MB, total_mem: 0.01 MB + layer: param_mem.blocks.0.norm2.weight, layer_mem: 0.01 MB, total_mem: 0.01 MB + layer: param_mem.blocks.0.mlp, layer_mem: 0.00 MB, total_mem: 258.00 MB + layer: param_mem.blocks.0.mlp.w1, layer_mem: 0.00 MB, total_mem: 86.00 MB + layer: param_mem.blocks.0.mlp.w1.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB + layer: param_mem.blocks.0.mlp.w2, layer_mem: 0.00 MB, total_mem: 86.00 MB + layer: param_mem.blocks.0.mlp.w2.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB + layer: param_mem.blocks.0.mlp.w3, layer_mem: 0.00 MB, total_mem: 86.00 MB + layer: param_mem.blocks.0.mlp.w3.weight, layer_mem: 86.00 MB, total_mem: 86.00 MB + ...... + grads_layout: + ...... + os_params_layout: + ...... + os_state_layout: + ...... + activation_base_layout: + ...... + +模型参数的太阳图示例如下: + +.. figure:: ../../imgs/params_memory_sunburst.png + :scale: 50% + :class: with-border + .. autoclass:: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler :members: diff --git a/doc/en/usage.md b/doc/en/usage.md index f8809d0..d115fb1 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -112,19 +112,19 @@ If you want to load a model checkpoint when starting the training, you can confi ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt" LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf - load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step - load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step - load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True + # When resuming training from a breakpoint,: + # (1) 'path' is the path of the loaded checkpoint. + # (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), ) ``` Note: -- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time. - If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS. The configuration for the model is as follows: diff --git a/doc/imgs/params_memory_sunburst.png b/doc/imgs/params_memory_sunburst.png new file mode 100644 index 0000000..c3ee8bc Binary files /dev/null and b/doc/imgs/params_memory_sunburst.png differ diff --git a/doc/imgs/torch_profiler_trace.png b/doc/imgs/torch_profiler_trace.png new file mode 100644 index 0000000..76129ae Binary files /dev/null and b/doc/imgs/torch_profiler_trace.png differ diff --git a/doc/usage.md b/doc/usage.md index c00b03e..1b98c10 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -101,18 +101,17 @@ data = dict( 如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置: ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt" LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径 checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf - load_model_only_folder=MODEL_ONLY_FOLDER, # 加载模型初始权重的路径,只加载模型权重,不加载优化器权重,训练将从第一个 step 开始 - load_ckpt_folder=LOAD_CKPT_FOLDER, # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 - load_optimizer=True, # 断点续训时,是否需要加载优化器权重,默认值为 True + # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 + # content 表示哪些状态会被加载,支持: "model", "sampler", "optimizer", "scheduler", "all" + # ckpt_type 表示加载的模型类型,目前支持: "internlm" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), ) ``` 注意: -- `load_model_only_folder`与`load_ckpt_folder`不能同时设置 - 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上 模型相关关键参数配置如下所示: diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index f1de5ad..968489c 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -18,6 +18,7 @@ import torch.distributed as dist from internlm.utils.common import SingletonMeta from internlm.utils.logger import get_logger +from internlm.utils.timeout import LLM_NCCL_TIMEOUT from . import process_group_initializer as pgroup_initializer from .process_group_initializer import ParallelMode @@ -36,7 +37,7 @@ class Config(dict): config (dict): The dict object to be wrapped. """ - def __init__(self, config: dict = None): + def __init__(self, config: dict = None): # pylint: disable=W0231 if config is not None: for k, v in config.items(): self._add_item(k, v) @@ -100,7 +101,7 @@ class Config(dict): module_name = filepath.stem source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) - module = source_file.load_module() # pylint: disable=W4902,E1120 + module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 # load into config config = Config() @@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta): """ # initialize the default process group init_method = f"tcp://[{host}]:{port}" - dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) + dist.init_process_group( + rank=rank, + world_size=world_size, + backend=backend, + init_method=init_method, + timeout=LLM_NCCL_TIMEOUT, + ) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) if use_cpu: - cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None + cpu_group = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else None + ) else: cpu_group = None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) @@ -526,6 +537,7 @@ class ParallelContext(metaclass=SingletonMeta): if dpseed_with_tpoffset: dp_seed = seed + pipeline_offset * 1024 add_seed(ParallelMode.DATA, dp_seed) + add_seed(ParallelMode.DUMMY, dp_seed) # model parallel seeds are different across ranks if self.is_initialized(ParallelMode.TENSOR): @@ -533,7 +545,11 @@ class ParallelContext(metaclass=SingletonMeta): tp_seed = seed + tp_rank + pipeline_offset * 1024 add_seed(ParallelMode.TENSOR, tp_seed) - set_mode(ParallelMode.DATA) + # we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode + # during model construction), this is because the random state will be different in different tensor parallel + # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform + # additional random operations during the RowParallelLinear module building process. + set_mode(ParallelMode.DUMMY) seeds = get_seeds() seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index facb806..97e9ef0 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -9,6 +9,8 @@ from enum import Enum import torch.distributed as dist +from internlm.utils.timeout import LLM_NCCL_TIMEOUT + # parallel modes class ParallelMode(Enum): @@ -35,6 +37,9 @@ class ParallelMode(Enum): # runntime network test NETTEST = "nettest" + # dummy mode, only used during mode construction + DUMMY = "dummy" + class ProcessGroupInitializer(ABC): """An object, knowing the parallelism configuration, that initializes parallel groups. @@ -106,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer): for i in range(self.rank_num_per_dp_group): ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -158,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -218,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer): ) ) pipe_group_size = len(ranks) - pipe_group = dist.new_group(ranks) + pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else pipe_group + ) else: group_cpu = None @@ -268,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer): for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -324,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer): i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group for k in range(self.zero1_parallel_size) ] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -373,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer): rank = i * self.nettest_parallel_size + j if rank < self.world_size: ranks.append(rank) - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 0076349..1d8b61e 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -9,6 +9,7 @@ import torch from internlm.core.engine import Engine from internlm.utils.common import conditional_context +from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook @@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler): return output, loss + @llm_timeout(func_name="nopp_forward_backward_step") def forward_backward_step( self, engine: Engine, diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index fd0b23e..e9b6c64 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -15,6 +15,7 @@ from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel from internlm.utils.common import get_current_device, move_to_device from internlm.utils.logger import get_logger +from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook @@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler): return output, label, accum_loss + @llm_timeout(func_name="nointerleaved_forward_backward_step") def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -1247,6 +1249,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): # 3. Cooldown self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) + @llm_timeout(func_name="interleaved_forward_backward_step") def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed. diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 6fd40ce..18a8f6f 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -23,7 +23,15 @@ class TrainState: train_dl (DataLoader): The DataLoader object used for training. """ - def __init__(self, config) -> None: + def __init__(self, config, batch_sampler) -> None: + """ + Args: + config (Config): internlm config + batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is + asynchronous and prefetched, the batch_sampler state maintained inside the + dataloader are faster then the actual training progress, so we copy the + batch_sampler as the anchor point of ckpt reload. + """ # The number of batches produced by the data iterator self.batch_count: int = 0 # Used to store the number of samples consumed in the current epoch @@ -43,9 +51,20 @@ class TrainState: self.tensorboard_folder = config.tensorboard_folder - def init_batch_sampler(self, train_dl): - # Copy of the batch sampler from the DataLoader - self.batch_sampler = train_dl.batch_sampler.copy() + # learning rate + self.lr = config.adam.lr + + # smapler state + if batch_sampler: + self.init_batch_sampler(batch_sampler) + + def init_batch_sampler(self, batch_sampler): + """ + Args: + batch_sampler (torch.utils.data.Sampler): sampler. + """ + # make a copy of batch_sampler. + self.batch_sampler = batch_sampler.copy() # Iterator for the batch sampler self.batch_sampler_iter = iter(self.batch_sampler) @@ -61,26 +80,22 @@ class TrainState: return json.dumps(info, indent=4, sort_keys=True) - def load_state_dict(self, other_stuffs, train_dl): + def load_state_dict(self, other_stuffs): """ Resumes training from a checkpoint. Args: other_stuffs (dict): Other information needed to resume training. - train_dl (DataLoader): The DataLoader object used for training. """ - - self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"] self.num_consumed_tokens = other_stuffs["num_consumed_tokens"] self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"] - # compatible with previous checkpoints without this parameter - self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1 - # track the actual updates of sampler when using weighted sampling - if hasattr(self, "batch_sampler"): - self.batch_sampler = train_dl.batch_sampler.copy() - self.batch_sampler_iter = iter(self.batch_sampler) + # Because the ckpt save occurs after updating 'step_count', + # there is no need to increment 'step_count' here (Does our step count start from 0 ?), + # However, 'batch_count' is updating before ckpt storage, so it need to inc 1 when resume. + self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward + self.step_count = other_stuffs.get("step_count", self.batch_count) # resume tensorboard from older tensorboard_folder self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 576a53a..079c2cb 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -10,9 +10,10 @@ import torch from internlm.core.context import Config from internlm.core.context import global_context as gpc +from internlm.monitor import initialize_light_monitor from internlm.utils.common import get_master_node from internlm.utils.logger import get_logger -from internlm.utils.storage_manager import init_storage_manager +from internlm.utils.timeout import llm_timeout logger = get_logger(__file__) @@ -97,6 +98,13 @@ def args_sanity_check(): if "valid_every" not in data: data._add_item("valid_every", 0) + if "empty_cache_and_diag_interval" not in data: + data._add_item("empty_cache_and_diag_interval", 50) + + if "diag_outlier_ratio" not in data: + data._add_item("diag_outlier_ratio", 1.1) + data.diag_outlier_ratio = max(1, data.diag_outlier_ratio) + if gpc.is_rank_for_log(): logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"seq_len: {data.seq_len}") @@ -111,7 +119,7 @@ def args_sanity_check(): # processing the checkpoint config ckpt = gpc.config.ckpt if "enable_save_ckpt" not in ckpt: - ckpt._add_item("enable_save_ckpt", False) + ckpt._add_item("enable_save_ckpt", True) # Saving checkpoint args. if ckpt.enable_save_ckpt: @@ -137,9 +145,6 @@ def args_sanity_check(): if not ckpt.async_upload: ckpt._add_item("async_upload_tmp_folder", None) - if "snapshot_ckpt_folder" not in ckpt: - ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot")) - if "oss_snapshot_freq" not in ckpt: ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable. else: @@ -149,44 +154,23 @@ def args_sanity_check(): ckpt._add_item("async_upload", False) ckpt._add_item("async_upload_tmp_folder", None) ckpt._add_item("snapshot_ckpt_folder", None) - ckpt._add_item("snapshot_ckpt_folder", None) - - # Loading checkpoint args. - if "load_model_only_folder" not in ckpt: - ckpt._add_item("load_model_only_folder", None) if "load_ckpt_folder" not in ckpt: ckpt._add_item("load_ckpt_folder", None) - if "load_optimizer" not in ckpt: - ckpt._add_item("load_optimizer", True) - if "stop_file_path" not in ckpt: ckpt._add_item("stop_file_path", None) - if "load_given_ckpt" not in ckpt: - # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity + if "auto_resume" not in ckpt: + # If 'auto_resume' is not given, we set it to True, so internlm can have opportunity # to auto-load latest checkpoint. - ckpt._add_item("load_given_ckpt", False) - - if ckpt.load_given_ckpt: - # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder - if ckpt.load_ckpt_folder and ckpt.load_model_only_folder: - logger.warning( - "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ -and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" - ) - ckpt.load_model_only_folder = None + ckpt._add_item("auto_resume", True) if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}") logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}") logger.info(f"checkpoint_every: {ckpt.checkpoint_every}") - logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}") - - # initialization storage manager - init_storage_manager(ckpt) # tensorboard writer config if "enable_tb" not in gpc.config: @@ -277,9 +261,22 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" - # feishu webhook address for alerting - if "alert_address" not in gpc.config: - gpc.config._add_item("alert_address", None) + # monitoring default config + monitor_default_config = { + "alert_address": None, # compatible with old alert config + "monitor": { # new monitoring config + "alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None} + }, + } + + for key, value in monitor_default_config.items(): + if key not in gpc.config: + gpc.config._add_item(key, value) + + alert = gpc.config.monitor.alert + + if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log(): + logger.warning("alert is enable but alert_address is not set") optim_ckpt = gpc.config.hybrid_zero_optimizer if "zero_overlap_communication" in optim_ckpt: @@ -426,6 +423,7 @@ def launch_from_torch( ) +@llm_timeout(func_name="initialize_distributed_env") def initialize_distributed_env( config: str, launcher: str = "slurm", @@ -459,3 +457,20 @@ def initialize_distributed_env( if args_check: args_sanity_check() + + # init light monitor client + alert_config = gpc.config.monitor.alert + if alert_config.enable_feishu_alert and gpc.is_rank_for_log(): + light_monitor_address = alert_config.light_monitor_address + if light_monitor_address: + initialize_light_monitor(light_monitor_address) + else: + logger.warning("monitor address is none, monitor could not be used!") + + +def get_config_value(config, key, defalut): + try: + value = config[key] + except KeyError: + value = defalut + return value diff --git a/internlm/initialize/legacy/__init__.py b/internlm/initialize/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/internlm/initialize/legacy/launch.py b/internlm/initialize/legacy/launch.py new file mode 100644 index 0000000..8313654 --- /dev/null +++ b/internlm/initialize/legacy/launch.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from internlm.initialize.launch import get_config_value +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def auto_resume_sanity_check(ckpt_config): + load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None) + if load_given_ckpt is None: + return True # default value is True + else: + return not load_given_ckpt + + +def ckpt_info_sanity_check(ckpt_config): + load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None) + + load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None) + + if load_model_only_folder is not None: + assert ( + load_ckpt_folder is None + ), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ +# and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" + return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm") + else: + load_optimizer = get_config_value(ckpt_config, "load_optimizer", True) + + if isinstance(load_ckpt_folder, str): + if load_optimizer: + return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm") + else: + return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm") + elif load_ckpt_folder is None: + return None + else: + assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'" diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 32f29f8..5a3a4eb 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn -from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.utils import fused_dense_func_torch @@ -195,12 +195,6 @@ class FeedForward(nn.Module): device=device, dtype=dtype, ) - # need to assign tp attribute so that colossalai know it is tensor parallel module - - if gpc.get_world_size(ParallelMode.TENSOR) > 1: - for name in ["w1", "w2", "w3"]: - for param in getattr(self, name).parameters(): - setattr(param, IS_TENSOR_PARALLEL, True) def forward(self, x): out = self.w3(F.silu(self.w1(x)) * self.w2(x)) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index ceb4ac3..64ff4de 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module): device=device, dtype=dtype, ) + for _, param in self.mlp.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py index b100cde..2501d66 100644 --- a/internlm/monitor/__init__.py +++ b/internlm/monitor/__init__.py @@ -1,4 +1,11 @@ +from .alert import initialize_light_monitor, send_heartbeat from .monitor import initialize_monitor_manager, send_alert_message from .utils import set_env_var -__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"] +__all__ = [ + "send_alert_message", + "initialize_monitor_manager", + "set_env_var", + "initialize_light_monitor", + "send_heartbeat", +] diff --git a/internlm/monitor/alert.py b/internlm/monitor/alert.py index 78b6040..1772e7f 100644 --- a/internlm/monitor/alert.py +++ b/internlm/monitor/alert.py @@ -1,8 +1,59 @@ import json +import math +import os +import re import time +from typing import Dict import requests +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def initialize_light_monitor(monitor_address: str = None): + try: + from uniscale_monitoring import init_monitor + + init_monitor(monitor_address) + except Exception as e: + logger.warning(f"init monitor meet error: {e}") + + +def send_heartbeat(msg_type: str, msg: Dict): + def nan2none(v): + if isinstance(v, float) and math.isnan(v): + return None + return v + + try: + from uniscale_monitoring import send_meta + + data = {} + for k, v in msg.items(): + if isinstance(v, Dict): + for k1, v1 in v.items(): + new_k = f"{k}_{k1}".split(" ")[0] + new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) + data[new_k] = nan2none(v1) + else: + new_k = k.split(" ")[0] + new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) + data[new_k] = nan2none(v) + + if os.getenv("CLUSTER_NAME"): + data.update({"cluster": os.getenv("CLUSTER_NAME")}) + if msg_type == "train_metrics": + data.update({"msg_type": "train_metrics"}) + elif msg_type == "init_time": + data.update({"msg_type": "init_time"}) + elif msg_type == "stage_time": + data.update({"msg_type": "stage_time"}) + send_meta(data, timeout=0.1) + except Exception as e: + logger.warning(f"send heartbeat meet error: {e}") + def send_feishu_msg_with_webhook(webhook: str, title: str, message: str): """ diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index a8ef5a0..6a3b9dc 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -226,9 +226,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None): send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.") yield finally: - send_alert_message( - address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed." - ) + send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.") monitor_manager.stop_monitor() else: yield diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py index 3da5bbe..99051f4 100644 --- a/internlm/solver/optimizer/__init__.py +++ b/internlm/solver/optimizer/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .hybrid_zero_optim import HybridZeroOptimizer +from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff -__all__ = ["HybridZeroOptimizer"] +__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"] diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 700d0dc..5031fd3 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import ( from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.timeout import llm_timeout from .utils import compute_norm @@ -124,6 +125,7 @@ class HybridZeroOptimizer(BaseOptimizer): self._param_store = ParameterStore(ParallelMode.ZERO1) self._grad_store = GradientStore(ParallelMode.DATA) self._bucket_store = BucketStore(ParallelMode.DATA) + self._bucket_in_progress = [] # fp16 and fp32 params for mixed precision training self._fp16_param_groups = dict() @@ -133,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer): # self._overlap_communication = overlap_communication self._reduce_bucket_size = reduce_bucket_size + self._comm_bcast_stream = torch.cuda.Stream() + # gradient scaler self.grad_scaler = DynamicGradScaler( initial_scale=initial_scale, @@ -231,13 +235,6 @@ class HybridZeroOptimizer(BaseOptimizer): # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled. self.skip_grad_reduce = False - # initialize communication stream for - # communication-computation overlapping - if self._overlap_sync_grad: - self._comm_stream = torch.cuda.Stream() - else: - self._comm_stream = torch.cuda.current_stream() - # reduction hook is only used if overlapping communication # if it is stage 1 without overlapping, no hook will be attached if self._overlap_sync_grad: @@ -383,34 +380,41 @@ class HybridZeroOptimizer(BaseOptimizer): def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): grad_buckets_by_dtype = split_half_float_double(grads) - + next_bucket_list = [] + # add parameters into bucket for reduction for tensor_list in grad_buckets_by_dtype: param_bucket = TensorBucket(size=bucket_size) for tensor in tensor_list: param_bucket.add_to_bucket(tensor, allow_oversize=True) - if param_bucket.is_full_or_oversized(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() if not param_bucket.is_empty(): self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + next_bucket_list.append(param_bucket) + + # wait for the completion of previouce bucket list reduction, and do unflatten_and_copy() + # here we can also overlap the communication with some memcpy operation caused by bucket.flatten() + for bucket in self._bucket_in_progress: + bucket.commu_handle.wait() + bucket.unflatten_and_copy() + bucket.empty() + self._bucket_in_progress = [] + self._param_store.clear_grads_of_previous_reduced_params() + + # after the completion of bucket list reduction, add new buckets into _bucket_in_progress + self._bucket_in_progress = next_bucket_list.copy() def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): - if self._overlap_sync_grad: - self._comm_stream.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() + # flatten the tensors and do allreduce + bucket.flatten() + bucket.commu_handle = reduce_tensor( + tensor=bucket.get_flat_tensor(), + dtype=None, + dst_rank=reduce_rank, + parallel_mode=ParallelMode.DATA, + ) - with torch.cuda.stream(self._comm_stream): - flat = bucket.flatten() - reduced_flat = reduce_tensor( - tensor=flat, - dtype=self.dtype, - dst_rank=reduce_rank, - parallel_mode=ParallelMode.DATA, - ) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._zero_local_rank: - bucket.unflatten_and_copy(reduced_flat) + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._zero_local_rank: + bucket.set_unflatten_and_copy_flag(flag=True) def _has_inf_or_nan(self, tensor): try: @@ -506,6 +510,7 @@ class HybridZeroOptimizer(BaseOptimizer): return norm + @llm_timeout(func_name="optim_step") def step(self, closure=None): """Performs a single optimization step. @@ -534,10 +539,13 @@ class HybridZeroOptimizer(BaseOptimizer): groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) # clear reduced grads - if self._overlap_sync_grad: - # grads in the last bucket is reduced - self._comm_stream.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() + # grads in the last bucket is reduced + for bucket in self._bucket_in_progress: + bucket.commu_handle.wait() + bucket.unflatten_and_copy() + bucket.empty() + self._bucket_in_progress = [] + self._param_store.clear_grads_of_previous_reduced_params() # compute norm for gradients in the last bucket total_norms = {} @@ -562,6 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer): # check for overflow found_inf = False + found_nan = False # if there is INF values in grades, compute_norm func would also returns -1 # thus, we try to avoid call _check_overflow here # found_inf = self._check_overflow() @@ -570,21 +579,36 @@ class HybridZeroOptimizer(BaseOptimizer): if -1 in norms.values(): found_inf = True + if -2 in norms.values(): + found_nan = True + loss_scale = float(self.loss_scale.item()) # backup if gpc.config.model.dtype is not torch.float32: self.grad_scaler.update(found_inf) + # update loss scale if overflow occurs if found_inf: if gpc.is_rank_for_log(): logger.warning("Overflow occurs, please check it.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message="Overflow occurs, please check it.", ) self._grad_store._averaged_gradients = dict() self.zero_grad() return False, norms + if found_nan: + if gpc.is_rank_for_log(): + logger.warning("Nan grad norm occurs, please check it.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message="Nan grad norm occurs, please check it.", + ) + self._grad_store._averaged_gradients = dict() + self.zero_grad() + return False, norms + # copy the grad of fp16 param to fp32 param single_grad_partition_groups = [] for group_id in range(self.num_param_groups): @@ -624,7 +648,9 @@ class HybridZeroOptimizer(BaseOptimizer): if gpc.config.model.dtype is not torch.float32: if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0: self._unscale_and_clip_grads( - single_grad_partition_groups, list(global_norm_groups.values()), loss_scale + single_grad_partition_groups, + list(global_norm_groups.values()), + loss_scale, ) # update the parameters @@ -645,7 +671,9 @@ class HybridZeroOptimizer(BaseOptimizer): fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) - self.broadcast_params() + torch.cuda.synchronize() + with torch.cuda.stream(self._comm_bcast_stream): + self.broadcast_params() timer("step").stop() @@ -771,3 +799,17 @@ class HybridZeroOptimizer(BaseOptimizer): if "zero_devide_optim_plan" in states: self.params_per_rank_id_dict = states["zero_devide_optim_plan"] + + +def reload_zero_fp32_buff(optimizer): + # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value. + # Or we must ensure that loading model weights must be done before zero is initialized. + if isinstance(optimizer, HybridZeroOptimizer): + for group_id, param_group in enumerate(optimizer.optim.param_groups): + if optimizer.param_group_has_params[group_id]: + # flatten fp16 params have already been updated by 'load_model_checkpoint' + fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group( + optimizer._zero_local_rank, group_id + ) + # param_group["params"] is fp32 flatten optimizer states of this zero rank. + param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 05a44d2..adab6c9 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -249,11 +249,17 @@ class ParameterStore(BaseStore): if not last_bucket: if group_id not in self._former_bucket_reduced_param: return [], [] - return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id] + return ( + self._former_bucket_reduced_param[group_id], + self._former_bucket_reduced_grad[group_id], + ) else: if group_id not in self._last_bucket_reduced_param: return [], [] - return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id] + return ( + self._last_bucket_reduced_param[group_id], + self._last_bucket_reduced_grad[group_id], + ) def reset_reduced_data_for_compute_norm(self): self._former_bucket_reduced_param = {} @@ -277,6 +283,9 @@ class TensorBucket: self._max_size = size self._current_size = 0 self._bucket = [] + self._flat_tensor = None + self._unflatten_and_copy_flag = False + self.commu_handle = None @property def max_size(self): @@ -292,6 +301,15 @@ class TensorBucket: def is_empty(self): return len(self._bucket) == 0 + def set_unflatten_and_copy_flag(self, flag): + self._unflatten_and_copy_flag = flag + + def get_unflatten_and_copy_flag(self): + return self._unflatten_and_copy_flag + + def get_flat_tensor(self): + return self._flat_tensor + def add_to_bucket(self, tensor, allow_oversize=False): tensor_size = tensor.numel() @@ -312,11 +330,14 @@ class TensorBucket: def empty(self): self._bucket = [] self._size = 0 + self._flat_tensor = None + self.commu_handle = None def flatten(self): - return _flatten_dense_tensors(self._bucket) + self._flat_tensor = _flatten_dense_tensors(self._bucket) - def unflatten_and_copy(self, flat_tensor): - unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) - for old, new in zip(self._bucket, unflattened_tensor_list): - old.copy_(new) + def unflatten_and_copy(self): + if self._unflatten_and_copy_flag: + unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket) + for old, new in zip(self._bucket, unflattened_tensor_list): + old.copy_(new) diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 38e4560..dbfcc34 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode. :type parallel_mode: ParallelMode, optional """ # use the original dtype - if dtype is None: - dtype = tensor.dtype + # if dtype is None: + assert dtype is None + dtype = tensor.dtype # cast the data to specified dtype for reduce/all-reduce - if tensor.dtype != dtype: - tensor_to_reduce = tensor.to(dtype) - else: - tensor_to_reduce = tensor + # if tensor.dtype != dtype: + # tensor_to_reduce = tensor.to(dtype) + # else: + # tensor_to_reduce = tensor - world_size = gpc.get_world_size(parallel_mode) + # world_size = gpc.get_world_size(parallel_mode) + # tensor.div_(world_size) group = gpc.get_group(parallel_mode) - tensor_to_reduce.div_(world_size) # if rank is None, all reduce will be used # else, reduce is used use_all_reduce = dst_rank is None if use_all_reduce: - dist.all_reduce(tensor_to_reduce, group=group) + handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True) else: ranks_in_group = gpc.get_ranks_in_group(parallel_mode) global_rank = ranks_in_group[dst_rank] - dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group) + handle = dist.reduce( + tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True + ) - # recover the original dtype - if tensor.dtype != dtype and tensor is not tensor_to_reduce: - local_rank = gpc.get_local_rank(parallel_mode) - if use_all_reduce or dst_rank == local_rank: - tensor.copy_(tensor_to_reduce) - - return tensor + return handle def has_inf_or_nan(tensor): @@ -314,6 +311,9 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no if total_norm == float("inf") or total_norm == -float("inf"): total_norm = -1 + if math.isnan(total_norm): + total_norm = -2 + return total_norm diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 3f4c851..a24317e 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.context.random import set_mode from internlm.core.naive_amp import NaiveAMPModel from internlm.core.trainer import TrainState from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader @@ -24,7 +25,7 @@ from internlm.data.packed_dataset import ( get_packed_dataset_without_short_length, ) from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data -from internlm.monitor import set_env_var +from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor.monitor import monitor_manager as mm from internlm.solver.beta2_scheduler import Beta2Scheduler from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR @@ -39,10 +40,12 @@ from internlm.utils.parallel import ( sync_model_param_within_tp, ) from internlm.utils.registry import MODEL_INITIALIZER +from internlm.utils.timeout import llm_timeout logger = get_logger(__file__) +@llm_timeout(func_name="initialize_model") def initialize_model(): """ Initialize model with Automatic Mixed Precision. @@ -82,9 +85,14 @@ def initialize_model(): # the same across tensor parallelism. sync_model_param_within_tp(model) + # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random + # state in the same dp group are all the same. + set_mode(ParallelMode.DATA) + return model +@llm_timeout(func_name="initialize_optimizer") def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): """ Initialize optimizer. @@ -122,6 +130,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): return optimizer, beta2_scheduler, lr_scheduler +@llm_timeout(func_name="get_train_data_loader") def get_train_data_loader( num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None ): @@ -201,6 +210,7 @@ def get_train_data_loader( return train_dl, dataset_types +@llm_timeout(func_name="get_validation_data_loader") def get_validation_data_loader( num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None ): @@ -262,6 +272,7 @@ def get_validation_data_loader( return val_dls +@llm_timeout(func_name="load_new_batch") def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): """ Load and return the new batch data based on training data loader. @@ -319,6 +330,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None): ) +@llm_timeout(func_name="record_current_batch_training_metrics") def record_current_batch_training_metrics( get_tflops_func, logger, @@ -342,6 +354,7 @@ def record_current_batch_training_metrics( set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time())) + timer.store_last_timers() if success_update in (0, True): train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) if is_no_pp_or_last_stage(): @@ -404,6 +417,9 @@ def record_current_batch_training_metrics( else: writer.add_scalar(key=key, value=value, step=train_state.step_count) + if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0: + send_heartbeat("train_metrics", infos) + if update_panel: # metrics shown with dashboard panels panel_metrics = { @@ -429,4 +445,8 @@ def record_current_batch_training_metrics( logger.info(line) # if loss spike occurs, send alert info to feishu - mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item()) + mm.monitor_loss_spike( + alert_address=gpc.config.monitor.alert.feishu_alert_address, + step_count=batch_count, + cur_step_loss=loss.item(), + ) diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 27ae9bd..ddb4932 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -9,7 +9,9 @@ import torch.distributed as dist from flash_attn.modules.mha import FlashSelfAttention, SelfAttention from torch.utils import benchmark +from internlm.monitor import send_alert_message from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer try: import GPUtil @@ -24,6 +26,23 @@ from internlm.utils.common import get_current_device logger = get_logger(__file__) +def empty_cache_and_diag(batch_count, interval=50): + """empty cuda cache and run diag bench or tests.""" + if interval <= 0: + interval = 50 + if batch_count % int(interval) == 0: + # there is no need to do diag on the first batch + if batch_count > 0: + if gpc.is_rank_for_log(): + logger.info("Empty Cache and Diagnosis GPU/NCCL/Timer ...") + with torch.no_grad(): + timer_diagnosis() + bench_gpu() + bench_net() + # do empty_cache after the bench + torch.cuda.empty_cache() + + def benchmark_forward( test_fn, *inputs, @@ -81,14 +100,78 @@ def get_cpu_temperature(): return cpu_temperature +def timer_diagnosis(): + """Diagnosis running time""" + + if len(timer.names) == 0 or len(timer.times) == 0: + return + + world_size = gpc.get_world_size(ParallelMode.DATA) + if world_size < 2: + return + + # if gpc.is_rank_for_log(): + # logger.info("Diagnosis running timers ...") + + # detect slow rank compared to other ranks in the same DP group + running_time = torch.Tensor(timer.times).to(device=get_current_device()) + avg_time = running_time.detach().clone() + if world_size <= 4: + dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.DATA)) + else: + running_time_max = avg_time.detach().clone() + running_time_min = avg_time.detach().clone() + dist.all_reduce(running_time_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA)) + dist.all_reduce(running_time_min, op=torch.distributed.ReduceOp.MIN, group=gpc.get_group(ParallelMode.DATA)) + dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) + avg_time = (avg_time - running_time_max - running_time_min) / (world_size - 2) + + diag_result = running_time > avg_time * gpc.config.data.diag_outlier_ratio + diag_result = diag_result.tolist() + avg_time = avg_time.tolist() + + for slow, name, time, avg in zip(diag_result, timer.names, timer.times, avg_time): + if slow is False or avg < 0.5: + continue + msg = ( + f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than avg on {name}, " + f"Hostname {socket.gethostname()}, " + f"its time {time:.2f}, avg {avg:.2f}, " + f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}" + ) + logger.warning(msg) + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=msg, + ) + + # detect slow rank compared to historical timer data + for name, time in zip(timer.names, timer.times): + if name not in timer.hist or len(timer.hist[name]) < 5: + continue + hist_avg = sum(timer.hist[name]) / len(timer.hist[name]) + if time > hist_avg * gpc.config.data.diag_outlier_ratio and time > 0.5: + msg = ( + f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than hist avg on {name}, " + f"Hostname {socket.gethostname()}, " + f"its time {time:.2f}, hist_avg {hist_avg:.2f}, " + f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}" + ) + logger.warning(msg) + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=msg, + ) + + def bench_net(): """Benchmark nccl performance for slow node detection.""" if gpc.get_world_size(ParallelMode.GLOBAL) <= 1: return - if gpc.is_rank_for_log(): - logger.info("benchmarking network speed ...") + # if gpc.is_rank_for_log(): + # logger.info("benchmarking network speed ...") repeats = 100 input_data = torch.randn( @@ -113,20 +196,25 @@ def bench_net(): allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL) allreduce_time_avg = float(allreduce_time_avg.item()) - if allreduce_time_this >= allreduce_time_avg * 1.05: - logger.warning( + if allreduce_time_this >= allreduce_time_avg * gpc.config.data.diag_outlier_ratio: + msg = ( f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, " f"Hostname {socket.gethostname()}, " f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, " f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}" ) + logger.warning(msg) + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=msg, + ) def bench_gpu(use_flash_attn=True): """Benchmark single GPU performance for slow node detection.""" - if gpc.is_rank_for_log(): - logger.info("benchmarking gpu speed ...") + # if gpc.is_rank_for_log(): + # logger.info("benchmarking gpu speed ...") headdim = 64 dim = 2048 @@ -154,10 +242,15 @@ def bench_gpu(use_flash_attn=True): speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL) speed_avg = float(speed_avg.item()) - if speed_this <= speed_avg * 0.95: - logger.warning( + if speed_this <= speed_avg / gpc.config.data.diag_outlier_ratio: + msg = ( f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, " f"Hostname {socket.gethostname()}, " f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, " f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}" ) + logger.warning(msg) + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=msg, + ) diff --git a/internlm/utils/logger.py b/internlm/utils/logger.py index 679913a..6111553 100644 --- a/internlm/utils/logger.py +++ b/internlm/utils/logger.py @@ -84,7 +84,7 @@ def initialize_uniscale_logger( job_name and launch_time and file_name ), "If file_path is None, job_name, launch_time and file_name must be setted." log_file_name = file_name - log_folder = os.path.join(job_name, launch_time, "logs") + log_folder = os.path.join("RUN", job_name, launch_time, "logs") log_dir = os.path.join(log_folder, log_file_name) file_path = log_dir diff --git a/internlm/utils/megatron_timers.py b/internlm/utils/megatron_timers.py index e319a80..d5d89e5 100644 --- a/internlm/utils/megatron_timers.py +++ b/internlm/utils/megatron_timers.py @@ -16,8 +16,12 @@ class _Timer: self.start_time = time.time() self.stream = torch.cuda.current_stream() - def start(self): + def start(self, reset_all=True): """Start the timer.""" + # need to reset all timers in a new batch + if self.name_ == "one-batch" and reset_all is True: + megatron_timer.reset() + assert not self.started_, "timer has already been started" self.stream.synchronize() self.start_time = time.time() @@ -48,7 +52,7 @@ class _Timer: self.reset() # If timing was in progress, set it back. if started_: - self.start() + self.start(reset_all=False) return elapsed_ @@ -57,12 +61,29 @@ class Timers: def __init__(self): self.timers = {} + self.hist = {} + self.names = [] + self.times = [] def __call__(self, name): if name not in self.timers: self.timers[name] = _Timer(name) return self.timers[name] + def store_last_timers(self): + """Store timers to two list""" + self.names = [] + self.times = [] + for key, value in self.timers.items(): + senconds = round(float(value.elapsed(reset=False)), 4) + self.names.append(key) + self.times.append(senconds) + if key not in self.hist: + self.hist[key] = [] + self.hist[key].append(senconds) + if len(self.hist[key]) > 10: + self.hist[key].pop(0) + def write(self, names, writer, iteration, normalizer=1.0, reset=False): """Write timers to a tensorboard writer""" # currently when using add_scalars, diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 09bafa5..b8f7ad6 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -3,37 +3,136 @@ import copy import fcntl +import inspect import os import socket import time from enum import Enum -from typing import Dict +from typing import Callable, Dict, Union import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.initialize.launch import get_config_value +from internlm.initialize.legacy.launch import ( + auto_resume_sanity_check, + ckpt_info_sanity_check, +) from internlm.monitor import send_alert_message -from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.storage_manager import ( get_fns, get_storage_manager, + init_storage_manager, llm_load, llm_save, + try_get_storage_backend, ) +from internlm.utils.timeout import llm_timeout logger = get_logger(__file__) -class CheckpointType(Enum): +class CheckpointSaveType(Enum): NORMAL_CHECKPOINT = 1 SNAPSHOT_CHECKPOINT = 2 +class CheckpointLoadType(Enum): + INTERNLM = "internlm" + + +# The load method implemented by internlm by default does not use string representation types, +# but uses enumeration types defined in advance. +LOAD_TYPE_DICT = { + "internlm": CheckpointLoadType.INTERNLM, +} + + +class CheckpointLoadContent: + MODEL = "model" + SAMPLER = "sampler" + OPIMIZER = "optimizer" + SCHEDULAER = "scheduler" + + +class CheckpointLoadMethod: + """The registration class of the checkpoint loading method, + users can define their own custom ckpt loading methods.""" + + LOAD_FUNC_SIG = None + LOAD_TYPE_FUNC = {} + + @staticmethod + def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]: + if load_type.lower() in LOAD_TYPE_DICT: + # The ckpt load method implemented by internlm by default. + return LOAD_TYPE_DICT[load_type.lower()] + else: + # If it is a user-defined field, we do not do any conversion and represent it as a string. + return load_type + + @staticmethod + def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable): + if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC: + logger.warning(f"{load_type} has aleady been registed!") + return + + CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) + + if load_type == CheckpointLoadType.INTERNLM: + CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) + else: + if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG: + logger.warning( + f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}" + ) + + @staticmethod + def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]): + return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type] + + +class CheckpointLoadMask: + """ + According to the content field in the incoming ckpt_info, decide which components to load. + """ + + LOAD_CONTENT_DICT = { + "model": CheckpointLoadContent.MODEL, + "sampler": CheckpointLoadContent.SAMPLER, + "optimizer": CheckpointLoadContent.OPIMIZER, + "scheduler": CheckpointLoadContent.SCHEDULAER, + } + + def __init__(self, content: tuple) -> None: + self.load_set = set(map(lambda x: x.lower(), content)) + if "all" in self.load_set: + self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values()) + else: + self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content)) + + def need_load(self, content: CheckpointLoadContent): + return content in self.load_set + + def not_only_load(self, content: CheckpointLoadContent): + return content in self.load_set and len(self.load_set) > 1 + + def only_load(self, content: CheckpointLoadContent): + return set((content,)) == self.load_set + + def __str__(self) -> str: + return f"{self.load_set}." + + def __repr__(self) -> str: + return f"{self.load_set}." + + def get_model_topology(model): """ Returns: @@ -55,6 +154,66 @@ def get_model_topology(model): return topos +def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): + load_content_str = "" + load_ckpt_folder = load_info["path"] + load_content: CheckpointLoadMask = load_info["content"] + + if gpc.is_rank_for_log(): + logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") + + if load_content.need_load(CheckpointLoadContent.MODEL): + load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) + load_content_str += f"{CheckpointLoadContent.MODEL}, " + + if load_content.not_only_load(CheckpointLoadContent.MODEL): + # load training states. + load_context(load_ckpt_folder, train_state) + + # load optimzier states. + if load_content.need_load(CheckpointLoadContent.OPIMIZER): + load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) + load_content_str += f"{CheckpointLoadContent.OPIMIZER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!") + + # load lr scheduler states. + if load_content.need_load(CheckpointLoadContent.SCHEDULAER): + if ckpt_mm.lr_scheduler: + load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state) + load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!") + + # load dataloader sampler states. + if load_content.need_load(CheckpointLoadContent.SAMPLER): + if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler) + # track the actual updates of sampler when using weighted sampling + train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler) + load_content_str += f"{CheckpointLoadContent.SAMPLER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager skip reload 'batch_sampler'") + + # reload data state dict. + if hasattr(train_state, "data_state_dict"): + ckpt_mm.train_dl.dataset.load_state_dict( + llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder + ) + load_content_str += f"{CheckpointLoadContent.SAMPLER}, " + else: + if gpc.is_rank_for_log(): + logger.warning( + "CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!" + ) + return load_content_str + + def save_model_checkpoint(folder, model): """ Save the model according to the relationship between tp and dp. The principle is that the data of each tp @@ -233,15 +392,16 @@ def load_sampler(ckpt_path: str, sampler): torch.cuda.empty_cache() -def load_context(ckpt_path: str, train_dl, train_state: TrainState): +def load_context(ckpt_path: str, train_state: TrainState): context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt")) - train_state.load_state_dict(context_stuffs, train_dl) + train_state.load_state_dict(context_stuffs) if gpc.is_rank_for_log(): logger.info(f"reload train_state:{train_state}") torch.cuda.empty_cache() -def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState): +def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState): + learning_rate = train_state.lr scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt")) if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log(): logger.warning( @@ -270,7 +430,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train class CheckpointManager: """StorageManagerContext""" - def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None: + def __init__( + self, + ckpt_config, + model, + train_dl=None, + optimizer=None, + lr_scheduler=None, + model_config=None, + model_config_file=None, + feishu_address=None, + ) -> None: """ CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait @@ -283,22 +453,44 @@ class CheckpointManager: lr_scheduler (object): lr_scheduler obj. model_config (dict): model config. """ - self.enable_save_ckpt = ckpt_config.enable_save_ckpt - self.checkpoint_every = ckpt_config.checkpoint_every - self.save_ckpt_folder = ckpt_config.save_ckpt_folder - self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder - self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq - self.stop_file_path = ckpt_config.stop_file_path - self.load_model_only_folder = ckpt_config.load_model_only_folder + self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False) + self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100) + self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None) + self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50) + self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None) + if self.save_ckpt_folder: + self.snapshot_ckpt_folder = get_config_value( + ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot") + ) + self.async_upload_tmp_folder = get_config_value( + ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/" + ) + else: + self.snapshot_ckpt_folder = None + self.async_upload_tmp_folder = None + + self.async_upload = get_config_value(ckpt_config, "async_upload", False) + + # initialization storage manager + init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload) + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 - self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dl = train_dl self.model_config = model_config self.model_config_file = model_config_file + # Register defalut internlm ckpt load type. + self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt} + for ckpt_load_type in CheckpointLoadType: + CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type]) + + # Init alter file. if self.stop_file_path and gpc.get_global_rank() == 0: dir_path = os.path.dirname(self.stop_file_path) if dir_path != "" and not os.path.exists(dir_path): @@ -306,21 +498,35 @@ class CheckpointManager: with open(self.stop_file_path, "w", encoding="utf-8") as f: f.write("0") - if ckpt_config.load_given_ckpt is False: - # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder - latest_ckpt_path = self.query_lastest_ckpt() - if latest_ckpt_path: - self.load_ckpt_folder = latest_ckpt_path - else: - # At this time, we have to load model init weights and train from step 0. - self.load_ckpt_folder = self.load_model_only_folder - else: - self.load_ckpt_folder = ckpt_config.load_ckpt_folder + self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None) + if self.load_ckpt_info is None: # (legacy): Try Compatible with old interfaces + self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config) - if gpc.is_rank_for_log(): - logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'") - if self.stop_file_path is None: - logger.warning("no set stop_file_path, quit_signal_handler is disable") + # Auto-reload latest checkpoint, it will overwrite the setting of 'load_ckpt_info'. + self.auto_resume = get_config_value(ckpt_config, "auto_resume", None) + if self.auto_resume is None: # (legacy): Try Compatible with old interfaces + self.auto_resume = auto_resume_sanity_check(ckpt_config) + if self.auto_resume: + self.load_ckpt_info = self.query_lastest_ckpt() + + if self.stop_file_path is None and gpc.is_rank_for_log(): + logger.warning("no set stop_file_path, quit_signal_handler is disable") + + # convert to internal representation + if self.load_ckpt_info: + assert ( + "path" in self.load_ckpt_info + and "content" in self.load_ckpt_info + and "ckpt_type" in self.load_ckpt_info + ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')" + + # replace load_ckpt + self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) + self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"]) + + # test storage setting is ok. + if self.enable_save_ckpt: + self.try_ping_storage() def quit_signal_handler(self, train_state) -> bool: """ @@ -334,7 +540,7 @@ class CheckpointManager: Returns: bool: whether to quit. """ - now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT if self.stop_file_path is None: return now_break, now_save_ckpt, save_type @@ -365,24 +571,29 @@ now step_count is {train_state.step_count}", return now_break, now_save_ckpt, save_type - def try_save_checkpoint(self, train_state): - if not self.enable_save_ckpt: - return False - - save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT + def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): + save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: - save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT + save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: - save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: save_ckpts = singal_save_ckpts save_type = singal_save_type + return save_ckpts, save_type, now_break + + def try_save_checkpoint(self, train_state): + if not self.enable_save_ckpt: + return False + + save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state) + if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. self.storage_manager.wait() - if save_type == CheckpointType.SNAPSHOT_CHECKPOINT: + if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT: # Snapshot number, with only two snapshots written alternately. self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") @@ -412,7 +623,7 @@ now step_count is {train_state.step_count}", Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. """ ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) - if len(ckpt_list) == 0: + if ckpt_list is None or len(ckpt_list) == 0: return None, None max_normal_step = 0 @@ -435,14 +646,16 @@ now step_count is {train_state.step_count}", ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) max_step_0, max_step_1 = 0, 0 - for ckpt in ckpt_list_1: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) - for ckpt in ckpt_list_2: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + if ckpt_list_1: + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + if ckpt_list_2: + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 snap_step = max(max_step_0, max_step_1) @@ -452,11 +665,12 @@ now step_count is {train_state.step_count}", def query_latest_snapshot_step_local(self): max_step, max_step_path = 0, None - for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + save_ckpt_folder = self.save_ckpt_folder.split(":")[1] + for root, _, files in os.walk(save_ckpt_folder, followlinks=True): for fn in files: fn = fn.strip("/") if fn.endswith(".step"): - # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # We assume that both internlm ckpt and snapshot ckpt will store the '.step' file # as an integrity flag. step = int(fn.rsplit(".", maxsplit=1)[0]) if max_step < step: @@ -466,100 +680,55 @@ now step_count is {train_state.step_count}", return max_step_path, max_step def query_lastest_ckpt(self): - latest_checkpoint = None + latest_ckpt, step = None, -1 # Training was automatically restarted by the process, forcing the latest snapshot to be read. if self.save_ckpt_folder: - if self.save_ckpt_folder.startswith("boto3"): - latest_checkpoint, step = self.query_latest_snapshot_step_boto3() - elif self.save_ckpt_folder.startswith("local"): - latest_checkpoint, step = self.query_latest_snapshot_step_local() - else: - latest_checkpoint, step = None, 0 + backend, _ = try_get_storage_backend(self.save_ckpt_folder) + if backend == "boto3": + latest_ckpt, step = self.query_latest_snapshot_step_boto3() + if latest_ckpt and not latest_ckpt.startswith("boto3:"): + latest_ckpt = ":".join(["boto3", latest_ckpt]) + elif backend == "local": + latest_ckpt, step = self.query_latest_snapshot_step_local() + if latest_ckpt and not latest_ckpt.startswith("local:"): + latest_ckpt = ":".join(["local", latest_ckpt]) - if latest_checkpoint is not None: - if gpc.is_rank_for_log(): - logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") - send_alert_message( - address=self.feishu_address, - message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", - ) - else: - if gpc.is_rank_for_log(): - send_alert_message( - address=self.feishu_address, - message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", - ) + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...") - return latest_checkpoint + return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm") - def try_load_model(self, current_time=""): - model_load_path = None + def try_resume_training(self, train_state: TrainState, current_time=""): - if self.load_ckpt_folder and self.load_model_only_folder: - raise ValueError( - "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ -if you only need to load model weights (for example starting an SFT task for the first time), \ -set load_model_only_folder path, if you need to resume training from ckpt, \ -set load_ckpt_folder or use default value \ -(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" - ) - - if self.load_ckpt_folder: - if gpc.is_rank_for_log(): - logger.info( - f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_ckpt_folder - elif self.load_model_only_folder: - if gpc.is_rank_for_log(): - logger.info( - f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_model_only_folder - else: + if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None: if gpc.is_rank_for_log(): logger.info( f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) + else: + load_path = self.load_ckpt_info["path"] + load_content = self.load_ckpt_info["content"] + load_type = self.load_ckpt_info["ckpt_type"] - # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=self.model) + load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type) + load_content_str = load_func(self, self.load_ckpt_info, train_state) - def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl): - """Attempt to restore the training state of the last ckpt. + # If we only load model weight, we need rewrite zero optim's fp32 buffer. + if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer): + reload_zero_fp32_buff(self.optimizer) - Args: - lr_scheduler (_LRScheduler): lr_scheduler object. - optimizer (Optimizer): optimizer object. - lr (float): learning rate. - train_state (dict): traing states. - train_dl (DataLoader): traning dataloader object - """ - if self.load_ckpt_folder is not None: - # load optimzier states. - if self.load_optimizer: - load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) - # load lr scheduler states. - load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(self.load_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. - if hasattr(train_state, "batch_sampler") and not isinstance( - train_state.batch_sampler, torch.utils.data.sampler.BatchSampler - ): - load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) - if hasattr(train_state, "data_state_dict"): - train_dl.dataset.load_state_dict( - llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder + if gpc.is_rank_for_log(): + logger.info(f"load_ckpt_info : {self.load_ckpt_info}") + logger.info( + f"===========Resume training from `{load_path}` {current_time} on host:" + f"{socket.gethostname()}===========" ) - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler + if load_content_str: + logger.info(f"===========Load contents are: {load_content_str}") + @llm_timeout(func_name="save_checkpoint") def save_checkpoint( self, folder, @@ -600,8 +769,10 @@ set load_ckpt_folder or use default value \ ) if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if scheduler: + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if hasattr(train_state, "batch_sampler") and not isinstance( train_state.batch_sampler, torch.utils.data.sampler.BatchSampler ): @@ -631,3 +802,12 @@ set load_ckpt_folder or use default value \ def set_save_folder(self, folder, step): self.storage_manager.latest_save_folder = folder self.storage_manager.latest_save_step = step + + def try_ping_storage(self): + if gpc.get_global_rank() % 8 == 0: + buff = torch.ones((1, 64, 64), dtype=torch.bfloat16) + test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping") + self.storage_manager.save(test_fn, buff) + self.storage_manager.wait() + self.storage_manager.load(test_fn) + del buff diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index c7b71f4..36bd105 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -46,12 +46,12 @@ def get_fns(fp: str): return storage_manager.get_fns(fp) -def llm_load(fp: str, *args, **kwargs): - return storage_manager.load(fp, *args, **kwargs) +def llm_load(fp: str, **kwargs): + return storage_manager.load(fp, **kwargs) -def llm_save(save_path: str, saved_obj: Any, *args, **kwargs): - storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs) +def llm_save(save_path: str, saved_obj: Any, **kwargs): + storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs) class StorageClient: @@ -63,19 +63,23 @@ class StorageClient: self.handler = handler @staticmethod - def load(client, load_path: str, *args, **kwargs): + def load(*args, **kwargs): raise NotImplementedError @staticmethod - def sync_upload_fileobj(*args, saved_obj=None, **kwargs): + def sync_upload_fileobj(*args, **kwargs): raise NotImplementedError @staticmethod - def assert_fp_exists(client): + def async_upload_fileobj(*args, **kwargs): raise NotImplementedError @staticmethod - def get_fns(client): + def assert_fp_exists(*args, **kwargs): + raise NotImplementedError + + @staticmethod + def get_fns(*args, **kwargs): raise NotImplementedError @@ -92,40 +96,65 @@ class Boto3MetaInfo: async_upload_fn: callable, local_nvme_path=None, ) -> None: - self.is_async = is_async + # all need info. self.client = handler self.bucket_name = bucket_name - self.endpoint = endpoint self.file_path = file_path - self.async_upload_fn = async_upload_fn + # only save need info. self.local_nvme_path = local_nvme_path + self.is_async = is_async + self.endpoint = endpoint + self.async_upload_fn = async_upload_fn def __str__(self) -> str: return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \ local_nvme_path: {self.local_nvme_path}" + @staticmethod + def unpack_boto3_save_meta(meta): + if meta.is_async: + return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path + else: + return meta.client, meta.bucket_name, meta.file_path + + @staticmethod + def unpack_boto3_nosave_meta(meta): + return meta.client, meta.bucket_name, meta.file_path + class LocalMetaInfo: """Local meta info for save/load etc.""" - def __init__(self, handler: StorageClient, dest_path: str) -> None: - self.is_async = False - self.client = handler - self.dest_path = dest_path + def __init__(self, file_path: str) -> None: + self.file_path = file_path self.async_upload_fn = None + self.is_async = False + + @staticmethod + def unpack_local_save_meta(meta): + return (meta.file_path,) + + @staticmethod + def unpack_local_nosave_meta(meta): + return (meta.file_path,) -def unpack_meta(meta): - args = [] - is_async = meta.is_async - for k, v in meta.__dict__.items(): - if k in ("endpoint", "async_upload_fn", "is_async"): - continue - if not is_async and k in ("local_nvme_path",): - continue - args.append(v) +def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]): + if isinstance(meta, Boto3MetaInfo): + return Boto3MetaInfo.unpack_boto3_save_meta(meta) + elif isinstance(meta, LocalMetaInfo): + return LocalMetaInfo.unpack_local_save_meta(meta) + else: + raise ValueError(f"unkonwn meta info: {type(meta)}") - return args + +def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]): + if isinstance(meta, Boto3MetaInfo): + return Boto3MetaInfo.unpack_boto3_nosave_meta(meta) + elif isinstance(meta, LocalMetaInfo): + return LocalMetaInfo.unpack_local_nosave_meta(meta) + else: + raise ValueError(f"unkonwn meta info: {type(meta)}") def compute_file_md5_by_chunk(file_name: str): @@ -136,6 +165,22 @@ def compute_file_md5_by_chunk(file_name: str): return hash_md5.hexdigest() +def try_get_storage_backend(path: str): + sre = path.split(":", maxsplit=1) + if len(sre) == 1: + if path.startswith("s3:"): + backend = "boto3" + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") + else: + backend = "local" + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") + return backend, sre + else: + return sre[0], sre[1] # (backend_prefix, splited_path) + + class Boto3Client(StorageClient): """ Boto3Client @@ -189,13 +234,11 @@ class Boto3Client(StorageClient): ) @staticmethod - def sync_upload_fileobj( - handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs - ): # pylint: disable=W0613 + def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs): assert saved_obj is not None, "saved_obj is None!" try: with io.BytesIO() as f: - torch.save(saved_obj, f, *args, **kwargs) + torch.save(saved_obj, f, **kwargs) f.seek(0) handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config) except handler.botocore.exceptions.EndpointConnectionError as exc: @@ -204,14 +247,7 @@ class Boto3Client(StorageClient): ) from exc @staticmethod - def load( - handler, - bucket_name: str, - fp: str, - local_nvme_path: str, # pylint: disable=W0613 - *args, - **kwargs, - ) -> Dict: + def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict: """ Args: fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt @@ -220,7 +256,7 @@ class Boto3Client(StorageClient): with io.BytesIO() as f: handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config) f.seek(0) - states = torch.load(f, *args, **kwargs) + states = torch.load(f, **kwargs) except handler.botocore.exceptions.EndpointConnectionError as exc: raise RuntimeError( f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}" @@ -228,24 +264,37 @@ class Boto3Client(StorageClient): return states @staticmethod - def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613 + def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613 assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp @staticmethod - def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613 + def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613 + re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp) + if "Contents" in re: + return len(list(re["Contents"])) > 0 + else: + return False + + @staticmethod + def get_fns(handler, bucket_name: str, fp: str): """ Ref: https://stackoverflow.com/questions/54314563/ how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2 """ - paginator = handler.client.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] - for page in pages: - if "Contents" in page: - for obj in page["Contents"]: - pth: str = obj["Key"] - folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) - return list(set(folder_name_list)) + if Boto3Client.is_fp_exists(handler, bucket_name, fp): + paginator = handler.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) + folder_name_list = [] + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) + else: + if gpc.is_rank_for_log(): + logger.warning(f"'{fp}' not found!") + return None @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -273,37 +322,35 @@ class LocalClient(StorageClient): super().__init__(None) @staticmethod - def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs): - assert isinstance(handler, LocalClient) + def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs): assert saved_obj is not None fp_dirname = os.path.dirname(fp) if not os.path.exists(fp_dirname): os.makedirs(fp_dirname, exist_ok=True) - torch.save(saved_obj, fp, *args, **kwargs) + torch.save(saved_obj, fp, **kwargs) @staticmethod - def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613 - assert isinstance(handler, LocalClient) - assert os.path.exists(fp), f"{fp} is not found!" - with open(fp, "rb") as f: - states = torch.load(f, *args, **kwargs) + def load(load_path: str, **kwargs): + assert os.path.exists(load_path), f"{load_path} is not found!" + with open(load_path, "rb") as f: + states = torch.load(f, **kwargs) return states @staticmethod - def assert_fp_exists(handler, folder): - assert isinstance(handler, LocalClient) + def assert_fp_exists(folder): assert os.path.exists(folder), folder @staticmethod - def get_fns(handler, folder): - assert isinstance(handler, LocalClient) - assert os.path.exists(folder), f"folder '{folder}' not exists!" - fns = os.listdir(folder) - return fns + def get_fns(folder): + if not os.path.exists(folder): + if gpc.is_rank_for_log(): + logger.warning(f"'{folder}' not found!") + return None + else: + return os.listdir(folder) @staticmethod - def delete_obj(handler, fp: str): - assert isinstance(handler, LocalClient) + def delete_obj(fp: str): if not os.path.isdir(fp): os.remove(fp) @@ -327,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI assert match is not None, f"url '{fp}' is not a valid boto3 url" bucket_name, endpoint = match.group(1), match.group(2) endpoint = "http://" + endpoint + ":80" - tmp_step_file = get_tmp_file_name(tmp_local_folder, fp) + if is_async: + tmp_step_file = get_tmp_file_name(tmp_local_folder, fp) + else: + tmp_step_file = None return Boto3MetaInfo( is_async=is_async, handler=None, @@ -341,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI def get_local_meta(fp: str) -> LocalMetaInfo: assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path" - return LocalMetaInfo(None, fp) + return LocalMetaInfo(fp) def get_mount_point_free_size(path: str): @@ -427,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta): logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!') raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}") - def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]: + def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]: """ example: local:/path/to/checkpoint @@ -436,17 +486,14 @@ class StorageManager(metaclass=SingletonMeta): Args: path (str): _description_ """ - try: - backend, path = path.split(":", maxsplit=1) - except Exception as exc: - raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc + backend, path = try_get_storage_backend(path) init_args = (None,) if backend == "local": meta_info = get_local_meta(path) backend_key = backend elif backend == "boto3": - meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode) + meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode) backend_key = backend + ":" + meta_info.endpoint init_args = (meta_info.endpoint,) if ( @@ -474,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta): def assert_fp_exists(self, folder) -> None: meta = self._get_client(path=folder) - meta.client.assert_fp_exists(*unpack_meta(meta)) + meta.client.assert_fp_exists(*unpack_nosave_meta(meta)) def get_fns(self, folder) -> List[str]: meta = self._get_client(path=folder) - return meta.client.get_fns(*unpack_meta(meta)) + return meta.client.get_fns(*unpack_nosave_meta(meta)) - def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs): - meta = self._get_client(path=save_path) + def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs): if async_upload is None: async_upload = self.async_mode + + if not save_path.startswith("boto3:"): + async_upload = False + + meta = self._get_client(save_path, async_upload) + if async_upload: assert ( self.tmp_local_folder @@ -492,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta): tmp_step_file = meta.local_nvme_path self._to_be_del_files.append(tmp_step_file) with open(tmp_step_file, "wb") as f: - torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) - self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) + torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) + self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) self.async_task_peeding = True else: - meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) + meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs) self.upload_count += 1 - def load(self, load_path: str, *args, **kwargs) -> Any: + def load(self, load_path: str, **kwargs) -> Any: self.wait() meta = self._get_client(path=load_path) - return meta.client.load(*unpack_meta(meta), *args, **kwargs) + return meta.client.load(*unpack_nosave_meta(meta), **kwargs) def delete_obj(self, fp: str): meta = self._get_client(path=fp) - meta.client.delete_obj(*unpack_meta(meta)) + meta.client.delete_obj(*unpack_nosave_meta(meta)) def _del_tmp_folder(self): for fp in self._to_be_del_files: @@ -594,23 +646,24 @@ class StorageManager(metaclass=SingletonMeta): if gpc.is_rank_for_log(): self.upload_count += 1 - if self.async_mode: + if self.async_mode and self.latest_save_folder: self.save( os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), - saved_obj=dict({"step": self.latest_save_step}), + to_save_obj=dict({"step": self.latest_save_step}), async_upload=False, ) + self.latest_save_folder = None storage_manager: StorageManager = None -def init_storage_manager(ckpt_config): +def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload): global storage_manager storage_manager = StorageManager( - ckpt_config.enable_save_ckpt, - tmp_local_folder=ckpt_config.async_upload_tmp_folder, - async_mode=ckpt_config.async_upload, + enable_save_ckpt, + tmp_local_folder=async_upload_tmp_folder, + async_mode=async_upload, ) diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py index 07a0911..7a96841 100644 --- a/internlm/utils/timeout.py +++ b/internlm/utils/timeout.py @@ -1,4 +1,13 @@ +import datetime +import os import signal +import socket +import traceback +from functools import wraps + +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) class Timeout: @@ -24,3 +33,81 @@ class Timeout: def __exit__(self, error_type, value, traceback): signal.alarm(0) + + +ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None) + + +timeout_threshold_dict = { + "initialize_distributed_env": 120, + "nopp_forward_backward_step": 360, + "initialize_model": 10, + "initialize_optimizer": 20, + "optim_step": 30, + "get_train_data_loader": 600, + "get_validation_data_loader": 60, + "load_new_batch": 10, + "record_current_batch_training_metrics": 10, + "save_checkpoint": 1200, + "interleaved_forward_backward_step": 600, + "nointerleaved_forward_backward_step": 600, +} + +if ENABLE_TIMEOUT is not None: + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60)))) +else: + timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0) + LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800) + + +def try_get_gpc_rank(): + try: + from internlm.core.context import global_context as gpc + + rank = gpc.get_global_rank() + except: # noqa # pylint: disable=bare-except + rank = "unknown" + + return f"host-{socket.gethostname()}-rank-{rank}" + + +def llm_timeout(seconds=0, func_name=None): + """timeout decorator, Note that this decorator cannot be reentrant, + otherwise the signal will be reset. + + Args: + seconds (int, optional): timeout threshold. Defaults to 300. + func_name (str, optional): the func who is been waited to timeout. + """ + + def decorator(func): + nonlocal func_name + if func_name is None: + func_name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + def _handle_timeout(signum, frame): + raise TimeoutError + + nonlocal seconds + seconds = timeout_threshold_dict.get(func_name, seconds) + + if seconds > 0: + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(seconds) + + try: + result = func(*args, **kwargs) + except TimeoutError as e: + logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}") + raise e + finally: + signal.alarm(0) + + return result + + return wrapper + + return decorator diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py new file mode 100644 index 0000000..d6a19b6 --- /dev/null +++ b/tests/test_utils/common_fixture.py @@ -0,0 +1,181 @@ +import os +import shutil +from subprocess import PIPE, STDOUT, Popen + +import pytest +import torch + +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.utils.common import SingletonMeta + +OSS_NAME = os.environ["OSS_BUCKET_NAME"] +OSS_IP = os.environ["OSS_IP"] +USER = os.environ["USER"] +JOB_NAME = "CI_TEST" +LOCAL_SAVE_PATH = "local:local_ckpt" + +BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" +BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + +ASYNC_TMP_FOLDER = "./async_tmp_folder" + + +# 1B +init_config = Config( + dict( + parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1), + model_type="INTERNLM", + adam=dict( + lr=1e-4, + ), + data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999), + model=dict( + checkpoint=False, + num_attention_heads=2, + embed_split_hidden=True, + vocab_size=103168, + embed_grad_scale=1, + parallel_output=True, + hidden_size=1024, + num_layers=2, + mlp_ratio=1, + apply_post_layer_norm=False, + dtype=torch.bfloat16, + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, + ), + resume_tb_folder="", + tensorboard_folder="", + ) +) + + +def init_naive_model(): + # let MODEL_INITIALIZER to work + import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import + from internlm.core.naive_amp import NaiveAMPModel + from internlm.utils.registry import MODEL_INITIALIZER + + model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model)) + model = NaiveAMPModel( + model=model, + output_to_fp32=False, + dtype=torch.bfloat16, + sync_buffer=False, + ) + return model + + +def init_naive_optim(model): + naive_optimizer = torch.optim.AdamW( + params=[{"params": model.parameters(), "weight_decay": 0.01}], + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + return naive_optimizer + + +def init_hybrid_optim(model): + naive_optimizer = torch.optim.AdamW( + params=[{"params": model.parameters(), "weight_decay": 0.01}], + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=Config( + dict( + fp16=dict( + initial_scale=2**16, + min_scale=1, + growth_interval=1000, + ), + growth_factor=2, + backoff_factor=0.5, + max_scale=2**24, + hysteresis=2, + ) + ), + zero_cfg=Config( + dict( + overlap_sync_grad=False, + overlap_sync_param=False, + reduce_bucket_size=512 * 1024 * 1024, + clip_grad_norm=1.0, + ) + ), + param_bcast_sync_handler=None, + ) + return optimizer + + +@pytest.fixture(autouse=True, scope="function") +def reset_singletons(): + SingletonMeta._instances = {} + + +def reset_seed(): + from internlm.core.context.random import _SEED_MANAGER + + _SEED_MANAGER.reset() + + +@pytest.fixture(scope="module") +def init_dist_and_model(rank=0, world_size=1): + from internlm.initialize import initialize_distributed_env + + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12377" + initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + + # setup + print("set up", flush=True) + model = init_naive_model() + # opim = init_naive_optim(model) + opim = init_hybrid_optim(model) + + yield model, opim + + # teardown + del model, opim + print("teardown", flush=True) + gpc.destroy() + reset_seed() + + +def enter_flag(text): + print(f"{text} begin!", flush=True) + yield + print(f"{text} end!", flush=True) + + +def del_tmp_file(): + try: + shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True) + except FileNotFoundError: + pass + + try: + shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True) + except FileNotFoundError: + pass + + try: + cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / " + with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output: + results, presults = "", "" + for line in iter(output.stdout.readline, b""): + results += str(line.rstrip()) + presults += line.rstrip().decode() + "\n" + print(presults, flush=True) + except FileNotFoundError: + pass diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py new file mode 100644 index 0000000..bd93436 --- /dev/null +++ b/tests/test_utils/test_model_checkpoint.py @@ -0,0 +1,247 @@ +import os + +import pytest +import torch + +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.core.trainer import TrainState +from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.utils.common import SingletonMeta +from internlm.utils.model_checkpoint import CheckpointManager +from internlm.utils.storage_manager import wait_async_upload_finish +from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + ASYNC_TMP_FOLDER, + BOTO_SAVE_PATH, + LOCAL_SAVE_PATH, + del_tmp_file, + init_dist_and_model, + reset_singletons, +) + +TOTAL_STEP = 6 + +CKPT_EVERY = 4 +SNPASHOT_EVERY = 2 + + +ckpt_config_list = [ + # Old interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=BOTO_SAVE_PATH, + load_optimizer=True, + checkpoint_every=CKPT_EVERY, + async_upload=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_model_only_folder=None, + load_given_ckpt=False, + load_ckpt_folder=None, + is_old_api=True, + ), + # Old interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=LOCAL_SAVE_PATH, + load_optimizer=True, + checkpoint_every=CKPT_EVERY, + async_upload=False, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_model_only_folder=None, + load_given_ckpt=False, + load_ckpt_folder=None, + is_old_api=True, + ), + # New interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=BOTO_SAVE_PATH, + checkpoint_every=CKPT_EVERY, + async_upload=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + is_old_api=False, + auto_resume=True, + ), + dict( + enable_save_ckpt=True, + save_ckpt_folder=LOCAL_SAVE_PATH, + checkpoint_every=CKPT_EVERY, + async_upload=False, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_ckpt_folder=None, + is_old_api=False, + auto_resume=True, + ), +] + + +def overwrite_optim_state(optim, set_value): + if isinstance(optim, HybridZeroOptimizer): + for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + # p.copy_(torch.full_like(p, set_value, dtype=p.dtype)) + p.data.fill_(set_value) + for group_id in range(len(optim._fp16_param_groups)): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( + rank=optim._zero_local_rank, group_id=group_id + ) + fp16_p.fill_(set_value) + else: + for group in optim.param_groups: + for p in group["params"]: + # p.copy_(torch.full_like(p, set_value, dtype=p.dtype)) + p.data.fill_(set_value) + + +def compare_optim_state(optim1, optim2): + re = True + if isinstance(optim1, HybridZeroOptimizer): + fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank + fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank + for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2): + re &= group_id_1 == group_id_2 + if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]: + re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2]) + else: + for group1, group2 in zip(optim1.param_groups, optim2.param_groups): + for p1, p2 in zip(group1["params"], group2["params"]): + re &= torch.equal(p1, p2) + return re + + +def compare_optim_value(optim, value): + re = True + if isinstance(optim, HybridZeroOptimizer): + for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) + for group_id in range(len(optim._fp16_param_groups)): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( + rank=optim._zero_local_rank, group_id=group_id + ) + re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype)) + else: + for group in optim.param_groups: + for p in group["params"]: + re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) + return re + + +def overwrite_model_value(model, value): + for p in model.parameters(): + # p.copy_(torch.full_like(p, value, dtype=p.dtype)) + p.data.fill_(value) + + +def compare_model_value(model, value): + re = True + for p in model.parameters(): + re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) + return re + + +@pytest.fixture(scope="function") +def del_tmp(): + del_tmp_file() + yield + del_tmp_file() + + +@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import + from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType + + ckpt_config = Config(ckpt_config) + assert ckpt_config.checkpoint_every < TOTAL_STEP + assert ckpt_config.oss_snapshot_freq < TOTAL_STEP + + model, opim = init_dist_and_model + train_state = TrainState(gpc.config, None) + if isinstance(opim, HybridZeroOptimizer): + print("Is HybridZeroOptimizer!", flush=True) + else: + print("Is naive Adam!", flush=True) + + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + latest_ckpt_step = None + for i in range(TOTAL_STEP + 1): + overwrite_model_value(model, i) + overwrite_optim_state(opim, i) + + train_state.batch_count = i + train_state.step_count += 1 + + save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state) + if save_ckpts: + latest_ckpt_step = i + + ckpt_mm.try_save_checkpoint(train_state) + + wait_async_upload_finish() + latest_ckpt_info = ckpt_mm.query_lastest_ckpt() + assert latest_ckpt_info is not None + latest_ckpt = latest_ckpt_info["path"] + if ckpt_mm.save_ckpt_folder.startswith("local"): + assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt + else: + assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt + + del ckpt_mm + SingletonMeta._instances = {} + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + ckpt_mm.try_resume_training(train_state) + assert latest_ckpt_step == 5 + assert train_state.step_count == 6 + assert train_state.batch_count == 6 + assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0] + assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0] + + if ckpt_mm.save_ckpt_folder.startswith("local:"): + ckpt_mm.load_ckpt_info = dict( + path=os.path.join(LOCAL_SAVE_PATH, "4"), + content=CheckpointLoadMask(("all",)), + ckpt_type=CheckpointLoadType.INTERNLM, + ) + else: + ckpt_mm.load_ckpt_info = dict( + path=os.path.join(BOTO_SAVE_PATH, "4"), + content=CheckpointLoadMask(("all",)), + ckpt_type=CheckpointLoadType.INTERNLM, + ) + + ckpt_mm.try_resume_training(train_state) + + assert train_state.step_count == 4 + assert train_state.batch_count == 4 + assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0] + assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0] + + +@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import + ckpt_config = Config(ckpt_config) + + model, opim = init_dist_and_model + SingletonMeta._instances = {} + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + ckpt_mm.try_ping_storage() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py new file mode 100644 index 0000000..32f905b --- /dev/null +++ b/tests/test_utils/test_storage_manager.py @@ -0,0 +1,89 @@ +import os + +import pytest +import torch + +from internlm.core.context.parallel_context import Config +from internlm.initialize.launch import get_config_value +from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + ASYNC_TMP_FOLDER, + BOTO_SAVE_PATH, + LOCAL_SAVE_PATH, + del_tmp_file, + init_dist_and_model, + reset_singletons, +) + +ASYNC_TMP_FOLDER = "./async_tmp_folder" +ckpt_config_list = [ + # async boto + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + async_upload=True, + save_folder=BOTO_SAVE_PATH, + test_id=0, + ), + # sync local + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=None, + async_upload=False, + save_folder=LOCAL_SAVE_PATH, + test_id=1, + ), + # sync boto + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=None, + async_upload=False, + save_folder=BOTO_SAVE_PATH, + test_id=2, + ), + # async local + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + async_upload=True, + save_folder=LOCAL_SAVE_PATH, + test_id=3, + ), +] + + +@pytest.fixture(scope="function") +def del_tmp(): + del_tmp_file() + yield + del_tmp_file() + + +@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument + from internlm.utils.storage_manager import ( + check_folder, + get_fns, + init_storage_manager, + llm_load, + llm_save, + wait_async_upload_finish, + ) + + ckpt_config = Config(ckpt_config) + enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False) + async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False) + async_upload = get_config_value(ckpt_config, "async_upload", False) + + init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload) + + tobj = torch.rand(64, 64) + save_fn = os.path.join(ckpt_config.save_folder, "test.pt") + llm_save(save_fn, tobj) + if ckpt_config.test_id == 0: + wait_async_upload_finish() + check_folder(save_fn) + assert get_fns(ckpt_config.save_folder)[0] == "test.pt" + load_obj = llm_load(save_fn, map_location="cpu") + assert 0 == ((load_obj != tobj).sum()) diff --git a/tests/test_utils/test_timeout.py b/tests/test_utils/test_timeout.py new file mode 100644 index 0000000..a3f15f9 --- /dev/null +++ b/tests/test_utils/test_timeout.py @@ -0,0 +1,119 @@ +import fcntl +import os +import time +from multiprocessing import Process + +import pytest +import torch +import torch.distributed as dist + +os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1" # noqa # pylint: disable=wrong-import-position +os.environ["NCCL_TIMEOUT"] = "5" +from internlm.utils.timeout import llm_timeout +from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + init_config, +) + +WORLD_SIZE = 2 + + +@llm_timeout(2, "fake_timeout_func") +def fake_timeout_func(): + time.sleep(10) + + +@llm_timeout(10, "nccl_timeout_func") +def nccl_timeout_func(rank): + # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880 + # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication. + buff = torch.ones([64, 64]).cuda(rank) + dist.all_reduce(buff) # lazy communicator init + torch.cuda.synchronize() + if rank == 0: + dist.all_reduce(buff) + torch.cuda.synchronize() # main thread will hang at here. + else: + time.sleep(9999) + + +@llm_timeout(10, "try_file_lock") +def try_file_lock(rank, stop_file_path): + if rank == 1: + time.sleep(5) + + with open(stop_file_path, "r", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) # rank 1 hang. + if rank == 0: + time.sleep(99999) # rank 0 hang. + f.seek(0) + f.read() + fcntl.flock(f, fcntl.LOCK_UN) + + +def local_timeout(rank, _): + + try: + fake_timeout_func() + except TimeoutError as e: + print(f"local_timeout, rank:{rank}, e:{e}", flush=True) + else: + assert False, "It should timeout!" + + +def gpc_timeout(rank, world_size): + + from internlm.initialize import initialize_distributed_env + + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12377" + initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + + try: + nccl_timeout_func(rank) + except TimeoutError as e: + print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True) + time.sleep(5) # wait rank 0 to be killed + else: + time.sleep(5) # give some time to let Watchdog kill rank 0. + assert False, "It should timeout!" + + +def file_lock_timeout(rank, _, stop_file_path): + if rank == 0: + with open(stop_file_path, "w"): + pass + try: + try_file_lock(rank, stop_file_path) + except TimeoutError as e: + print(e, flush=True) + else: + assert False, "It should timeout!" + finally: + if rank == 0: + os.remove(stop_file_path) + + +timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")] + + +@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list) +def test_timeout(timeout_func_and_args): + timeout_func, world_size, other_args = timeout_func_and_args + procs = [] + for i in range(world_size): + if other_args is None: + args = (i, world_size) + else: + args = (i, world_size, other_args) + proc = Process(target=timeout_func, args=args) + proc.start() + procs.append(proc) + + for proc in procs: + proc.join(15) + if proc.is_alive(): + proc.terminate() + proc.join() diff --git a/train.py b/train.py index 69cdd3c..ff15354 100644 --- a/train.py +++ b/train.py @@ -35,7 +35,7 @@ from internlm.utils.common import ( parse_args, ) from internlm.utils.evaluation import evaluate_on_val_dls -from internlm.utils.gputest import bench_gpu, bench_net +from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.model_checkpoint import CheckpointManager @@ -73,7 +73,6 @@ def main(args): total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every label_smoothing = gpc.config.loss.label_smoothing - lr = gpc.config.adam.lr get_tflops_func = partial( get_megatron_flops, @@ -96,21 +95,11 @@ def main(args): # initialize customed llm logger uniscale_logger = initialize_llm_logger(start_time=current_time) - # initialize and resume train state - train_state = TrainState(gpc.config) - # initialize model model = initialize_model() with open(args.config, "r") as f: config_lines = f.readlines() - ckpt_manager = CheckpointManager( - ckpt_config=gpc.config.ckpt, - model=model, - model_config=gpc.config.model, - model_config_file="".join(config_lines), - feishu_address=gpc.config.alert_address, - ) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -118,15 +107,25 @@ def main(args): # initialize the train and validation data loader train_dl, dataset_types = get_train_data_loader(num_worker=4) val_dls = get_validation_data_loader() - train_state.init_batch_sampler(train_dl) - # Loading model weights must be done before zero is initialized. - ckpt_manager.try_load_model(current_time) + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + # Loading other persistent training states. - ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl) + ckpt_manager.try_resume_training(train_state, current_time) # initialize customed llm writer writer = Writer( @@ -195,11 +194,7 @@ def main(args): with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): - if batch_count % 50 == 0: - torch.cuda.empty_cache() - bench_gpu() - bench_net() - + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) start_time = time.time() timer("one-batch").start() @@ -241,7 +236,7 @@ def main(args): if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case logger.warning(f"Warning: skip parameter update at step {batch_count}.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message=f"Warning: skip parameter update at step {batch_count}.", ) @@ -301,11 +296,15 @@ if __name__ == "__main__": assert hasattr(gpc, "config") and gpc.config is not None # initialize monitor manager context - with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address): + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): try: main(args) except Exception: logger.error( f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", ) - mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + ) diff --git a/version.txt b/version.txt index 6e8bf73..0ea3a94 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.0 +0.2.0