diff --git a/internlm/utils/common.py b/internlm/utils/common.py index f3b58c0..188a634 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -220,6 +220,43 @@ def get_megatron_flops( return tflops +def get_megatron_flops_2( + elapsed_time_per_iter, + checkpoint=False, + seq_len=2048, + hidden_size=12, + num_layers=32, + vocab_size=12, + global_batch_size=4, + global_world_size=1, + mlp_ratio=4, + use_swiglu=True, +): + """ + Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf + """ + + checkpoint_activations_factor = 4 if checkpoint else 3 + flashattn_activations_factor = 4.5 if checkpoint else 3.5 + + if use_swiglu: + mlp_ratio = mlp_ratio * 3 / 2 + + flops_per_iteration = ( + checkpoint_activations_factor + * (8 + mlp_ratio * 4) + * global_batch_size + * seq_len + * hidden_size**2 + * num_layers + + 4 * global_batch_size * seq_len**2 * hidden_size * num_layers * flashattn_activations_factor + + 6 * global_batch_size * seq_len * hidden_size * vocab_size + ) + + tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12)) + return tflops + + class DummyProfile: """ Dummy Profile.