From 4a3d15650ecb410332414c4b695b2dc0125346bc Mon Sep 17 00:00:00 2001 From: Sun Peng Date: Sat, 8 Jul 2023 18:55:31 +0800 Subject: [PATCH] fix(no_pp_scheduler): drop model out data and label if not used (#39) * fix(no_pp_scheduler): drop out and label if not used * Update train_performance.md * Update readme with new tested data * update some typos --- doc/en/train_performance.md | 47 +++++++++++++++++++----- doc/train_performance.md | 51 ++++++++++++++++++++------ internlm/core/no_pipeline_scheduler.py | 6 +-- 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/doc/en/train_performance.md b/doc/en/train_performance.md index 995a0f1..bd916c4 100644 --- a/doc/en/train_performance.md +++ b/doc/en/train_performance.md @@ -46,19 +46,46 @@ Throughput is defined as TGS, the average number of tokens processed per GPU per ### FLOPS Testing -The computational workload of model training is based on the FLOPS calculation method described in the [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) paper. To ensure constant FLOPS during training, the test configuration had `pack_sample_into_one=True`. The training used the following configuration: +The computational workload of model training is based on the FLOPS calculation method described in the [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) paper. To ensure constant FLOPS during training, the test configuration had `pack_sample_into_one=True`, `dtype=torch.bfloat16`. -Activation Checkpointing | tp | zero-1 | seq_len | micro_num | micro_bsz | -| --- | --- | ------ | ------- | --------- | --------- | -Disabled | 1 | 8 | 2048 | 4 | 2 | -Enabled | 1 | 8 | 2048 | 1 | 8 | -The test results are shown in the table below. InternLM can achieve `>180 TFLOPS` for 7B model on thousand-card scale. +When `Activation Ckpt` is enabled,the test results are shown in the table below. InternLM can achieve `>180 TFLOPS` for 7B model training with 1024 GPUs. + +- TGS: Tokens per GPU per Second + +- Global Bsz: 一个step中所有GPU处理的token数量 + +| TP | Zero1 | Pack Sample Into One | Activation Ckpt | GPU Num | Seq Len | Micro Bsz | Micro Num | Global Bsz | TGS | TFLOPS | +|-|-|-|-|-|-|-|-|-|-|-| +| 1 | 8 | TRUE | TRUE | 8 | 2048 | 8 | 1 | 0.125M | 3314 | 193 | +| 1 | 8 | TRUE | TRUE | 16 | 2048 | 8 | 1 | 0.25M | 3268 | 191 | +| 1 | 8 | TRUE | TRUE | 32 | 2048 | 8 | 1 | 0.5M | 3323 | 188 | +| 1 | 8 | TRUE | TRUE | 64 | 2048 | 8 | 1 | 1M | 3217 | 188 | +| 1 | 8 | TRUE | TRUE | 128 | 2048 | 8 | 1 | 2M | 3260 | 187 | +| 1 | 8 | TRUE | TRUE | 256 | 2048 | 8 | 1 | 4M | 3215 | 187 | +| 1 | 8 | TRUE | TRUE | 512 | 2048 | 8 | 1 | 8M | 3199 | 186 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 8 | 1 | 16M | 3163 | 184 | +| 1 | 8 | TRUE | TRUE | 512 | 2048 | 4 | 1 | 4M | 2963 | 173 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 2 | 1 | 4M | 2341 | 136 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 4 | 1 | 8M | 2796 | 160 | + +When `Activation Ckpt` is turned off, the test results are as shown in the table below: + +| TP | Zero1 | Pack Sample Into One | Activation Ckpt | GPU Num | Seq Len | Micro Bsz | Micro Num | Global Bsz | TGS | TFLOPS | +|-|-|-|-|-|-|-|-|-|-|-| +| 1 | 8 | TRUE | FALSE | 8 | 2048 | 2 | 4 | 0.125M | 4103 | 183 | +| 1 | 8 | TRUE | FALSE | 16 | 2048 | 2 | 4 | 0.25M | 3939 | 177 | +| 1 | 8 | TRUE | FALSE | 32 | 2048 | 2 | 4 | 0.5M | 3919 | 176 | +| 1 | 8 | TRUE | FALSE | 64 | 2048 | 2 | 4 | 1M | 3944 | 174 | +| 1 | 8 | TRUE | FALSE | 128 | 2048 | 2 | 4 | 2M | 3928 | 173 | +| 1 | 8 | TRUE | FALSE | 256 | 2048 | 2 | 4 | 4M | 3920 | 173 | +| 1 | 8 | TRUE | FALSE | 512 | 2048 | 2 | 4 | 8M | 3900 | 173 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 4 | 16M | 3625 | 160 | +| 1 | 8 | TRUE | FALSE | 512 | 2048 | 2 | 2 | 4M | 3084 | 139 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 1 | 4M | 2346 | 105 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 2 | 8M | 2817 | 124 | + -| Activation Checkpoint | 8 GPUs | 16 GPUs | 32 GPUs | 64 GPUs | 128 GPUs | 256 GPUs | 512 GPUs | 1024 GPUs | -| --------------------- | ------ | ------- | ------- | ------- | -------- | -------- | -------- | --------- | -| Disabled | 183 | 177 | 176 | 174 | 173 | 173 | 173 | 160 | -| Enabled | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
diff --git a/doc/train_performance.md b/doc/train_performance.md index 72c1509..f5ff0bf 100644 --- a/doc/train_performance.md +++ b/doc/train_performance.md @@ -29,7 +29,7 @@ InternLM中`zero1`的配置决定了优化器状态的分配范围。 ### 吞吐量测量 -吞吐量定义为TGS,平均每GPU每秒处理的token的数量(Tokens per GPU per Second)。在该项测试的训练配置中,`pack_sample_into_one=False`,`checkpoint=False`。测试结果如下表所示。采用`zero1=8,tp=1`,InternLM针对7B模型训练的扩展性,在千卡训练的加速效率可以达到`88%`。 +吞吐量定义为TGS,平均每GPU每秒处理的token的数量(Tokens per GPU per Second)。在该项测试的训练配置中,`pack_sample_into_one=False`,`checkpoint=False`, `dtype=torch.bfloat16`。测试结果如下表所示。采用`zero1=8,tp=1`,InternLM针对7B模型训练的扩展性,在千卡训练的加速效率可以达到`88%`。 | 并行配置 | 8卡 | 16卡 | 32卡 | 64卡 | 128卡 | 256卡 | 512卡 | 1024卡 | | ---------------- | ---- | ---- | ---- | ---- | ----- | ----- | ----- | ------ | @@ -44,19 +44,46 @@ InternLM中`zero1`的配置决定了优化器状态的分配范围。
### FLOPS测试 -模型训练的计算量参考 [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) 论文中FLOPS计算方式。为了保证训练过程中的FLOPS恒定,在该项测试的训练配置中,`pack_sample_into_one=True`,其余超参设置如下所示: +模型训练的计算量参考 [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) 论文中FLOPS计算方式。为了保证训练过程中的FLOPS恒定,在该项测试的训练配置中,`pack_sample_into_one=True`,`dtype=torch.bfloat16`。 -activation checkpoint | tp | zero-1 | seq_len | micro_num | micro_bsz | -| --- | --- | ---- | ---- | ---- |---- | -关闭 | 1 | 8 | 2048 | 4 | 2 | -开启 | 1 | 8 | 2048 | 1 | 8 | -测试结果如下表所示,InternLM针对7B模型的千卡训练,可以达到 `>180 TFLOPS`: -| activation checkpoint | 8卡 | 16卡 | 32卡 | 64卡 | 128卡 | 256卡 | 512卡 | 1024卡 | -| --------------- | --- | ---- | ---- | ---- | ----- | ----- | ----- | ------ | -| 关闭 | 183 | 177 | 176 | 174 | 173 | 173 | 173 | 160 | -| 开启 | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 | +当开启 Activation Ckpt后,测试结果如下表所示,InternLM针对7B模型的千卡训练,可以达到 `>180 TFLOPS`: + +- TGS: Tokens per GPU per Second + +- Global Bsz: The total number of processed tokens with all GPUs in a step + +| TP | Zero1 | Pack Sample Into One | Activation Ckpt | GPU Num | Seq Len | Micro Bsz | Micro Num | Global Bsz | TGS | TFLOPS | +|-|-|-|-|-|-|-|-|-|-|-| +| 1 | 8 | TRUE | TRUE | 8 | 2048 | 8 | 1 | 0.125M | 3314 | 193 | +| 1 | 8 | TRUE | TRUE | 16 | 2048 | 8 | 1 | 0.25M | 3268 | 191 | +| 1 | 8 | TRUE | TRUE | 32 | 2048 | 8 | 1 | 0.5M | 3323 | 188 | +| 1 | 8 | TRUE | TRUE | 64 | 2048 | 8 | 1 | 1M | 3217 | 188 | +| 1 | 8 | TRUE | TRUE | 128 | 2048 | 8 | 1 | 2M | 3260 | 187 | +| 1 | 8 | TRUE | TRUE | 256 | 2048 | 8 | 1 | 4M | 3215 | 187 | +| 1 | 8 | TRUE | TRUE | 512 | 2048 | 8 | 1 | 8M | 3199 | 186 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 8 | 1 | 16M | 3163 | 184 | +| 1 | 8 | TRUE | TRUE | 512 | 2048 | 4 | 1 | 4M | 2963 | 173 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 2 | 1 | 4M | 2341 | 136 | +| 1 | 8 | TRUE | TRUE | 1024 | 2048 | 4 | 1 | 8M | 2796 | 160 | + +当关闭 Activation Ckpt后,测试结果如下表所示: + +| TP | Zero1 | Pack Sample Into One | Activation Ckpt | GPU Num | Seq Len | Micro Bsz | Micro Num | Global Bsz | TGS | TFLOPS | +|-|-|-|-|-|-|-|-|-|-|-| +| 1 | 8 | TRUE | FALSE | 8 | 2048 | 2 | 4 | 0.125M | 4103 | 183 | +| 1 | 8 | TRUE | FALSE | 16 | 2048 | 2 | 4 | 0.25M | 3939 | 177 | +| 1 | 8 | TRUE | FALSE | 32 | 2048 | 2 | 4 | 0.5M | 3919 | 176 | +| 1 | 8 | TRUE | FALSE | 64 | 2048 | 2 | 4 | 1M | 3944 | 174 | +| 1 | 8 | TRUE | FALSE | 128 | 2048 | 2 | 4 | 2M | 3928 | 173 | +| 1 | 8 | TRUE | FALSE | 256 | 2048 | 2 | 4 | 4M | 3920 | 173 | +| 1 | 8 | TRUE | FALSE | 512 | 2048 | 2 | 4 | 8M | 3900 | 173 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 4 | 16M | 3625 | 160 | +| 1 | 8 | TRUE | FALSE | 512 | 2048 | 2 | 2 | 4M | 3084 | 139 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 1 | 4M | 2346 | 105 | +| 1 | 8 | TRUE | FALSE | 1024 | 2048 | 2 | 2 | 8M | 2817 | 124 |
-
\ No newline at end of file + + diff --git a/internlm/core/no_pipeline_scheduler.py b/internlm/core/no_pipeline_scheduler.py index 6cd8416..1f201e5 100644 --- a/internlm/core/no_pipeline_scheduler.py +++ b/internlm/core/no_pipeline_scheduler.py @@ -269,9 +269,9 @@ class NonPipelineScheduler(BaseScheduler): if return_loss: loss += _loss - - outputs.append(_output) - labels.append(_label) + if return_output_label: + outputs.append(_output) + labels.append(_label) if not return_output_label: outputs, labels = None, None