mirror of https://github.com/InternLM/InternLM
add a new get_tflops_func
parent cc20fa271a
commit 3253cbf48e
@@ -220,6 +220,43 @@ def get_megatron_flops(
     return tflops
 
 
+def get_megatron_flops_2(
+    elapsed_time_per_iter,
+    checkpoint=False,
+    seq_len=2048,
+    hidden_size=12,
+    num_layers=32,
+    vocab_size=12,
+    global_batch_size=4,
+    global_world_size=1,
+    mlp_ratio=4,
+    use_swiglu=True,
+):
+    """
+    Calculate FLOPs based on the Megatron paper: https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
+    """
+
+    checkpoint_activations_factor = 4 if checkpoint else 3
+    flashattn_activations_factor = 4.5 if checkpoint else 3.5
+
+    if use_swiglu:
+        mlp_ratio = mlp_ratio * 3 / 2
+
+    flops_per_iteration = (
+        checkpoint_activations_factor
+        * (8 + mlp_ratio * 4)
+        * global_batch_size
+        * seq_len
+        * hidden_size**2
+        * num_layers
+        + 4 * global_batch_size * seq_len**2 * hidden_size * num_layers * flashattn_activations_factor
+        + 6 * global_batch_size * seq_len * hidden_size * vocab_size
+    )
+
+    tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
+    return tflops
+
+
 class DummyProfile:
     """
     Dummy Profile.
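For a sense of how the new helper behaves, here is a minimal usage sketch. The configuration values below (step time, hidden size, vocabulary size, batch size, GPU count, MLP ratio) are illustrative assumptions rather than values taken from the commit, and the function is assumed to be in scope, e.g. copied from the diff above.

```python
# Minimal usage sketch for get_megatron_flops_2 (assumed to be in scope,
# e.g. copied from the diff above). All configuration values here are
# illustrative assumptions, not defaults from the repository.
tflops = get_megatron_flops_2(
    elapsed_time_per_iter=2.0,  # measured seconds per training step (assumed)
    checkpoint=False,           # no activation recomputation
    seq_len=2048,
    hidden_size=4096,           # 7B-scale width, for illustration
    num_layers=32,
    vocab_size=100000,          # hypothetical tokenizer size
    global_batch_size=512,      # sequences per step across all ranks
    global_world_size=64,       # total GPUs; the result is TFLOPS per GPU
    mlp_ratio=8 / 3,            # SwiGLU-style ratio, for illustration
    use_swiglu=True,
)
print(f"~{tflops:.1f} model TFLOPS per GPU")
```

Two details of the formula are worth noting: with recomputation off, the factor of 3 counts one forward pass plus a backward pass costing twice the forward (checkpointing raises it to 4 for the recomputed forward), and with mlp_ratio=4 and use_swiglu=False the dense term reduces to the familiar 72 * batch * seq_len * hidden_size^2 * num_layers from the Megatron-LM paper. Dividing by global_world_size makes the returned value a per-GPU throughput.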