add a new get_tflops_func

pull/456/head
mwiacx 2023-10-26 20:21:46 +08:00
parent cc20fa271a
commit 3253cbf48e
1 changed files with 37 additions and 0 deletions

View File

@ -220,6 +220,43 @@ def get_megatron_flops(
return tflops
def get_megatron_flops_2(
elapsed_time_per_iter,
checkpoint=False,
seq_len=2048,
hidden_size=12,
num_layers=32,
vocab_size=12,
global_batch_size=4,
global_world_size=1,
mlp_ratio=4,
use_swiglu=True,
):
"""
Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
"""
checkpoint_activations_factor = 4 if checkpoint else 3
flashattn_activations_factor = 4.5 if checkpoint else 3.5
if use_swiglu:
mlp_ratio = mlp_ratio * 3 / 2
flops_per_iteration = (
checkpoint_activations_factor
* (8 + mlp_ratio * 4)
* global_batch_size
* seq_len
* hidden_size**2
* num_layers
+ 4 * global_batch_size * seq_len**2 * hidden_size * num_layers * flashattn_activations_factor
+ 6 * global_batch_size * seq_len * hidden_size * vocab_size
)
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
return tflops
class DummyProfile:
"""
Dummy Profile.