mirror of https://github.com/InternLM/InternLM
add a new get_tflops_func
parent
cc20fa271a
commit
3253cbf48e
|
@ -220,6 +220,43 @@ def get_megatron_flops(
|
||||||
return tflops
|
return tflops
|
||||||
|
|
||||||
|
|
||||||
|
def get_megatron_flops_2(
    elapsed_time_per_iter,
    checkpoint=False,
    seq_len=2048,
    hidden_size=12,
    num_layers=32,
    vocab_size=12,
    global_batch_size=4,
    global_world_size=1,
    mlp_ratio=4,
    use_swiglu=True,
):
    """
    Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf

    The estimate is the sum of three terms (a hidden_size**2 term scaled by
    ``8 + 4 * mlp_ratio``, a seq_len**2 term, and a vocab_size term), divided
    by ``elapsed_time_per_iter * global_world_size * 10**12``.

    Args:
        elapsed_time_per_iter: Duration of one iteration; used as a divisor,
            so it must be non-zero. Presumably seconds -- confirm with caller.
        checkpoint: Selects the larger multipliers (4 and 4.5 instead of 3
            and 3.5) for the first two terms.
        seq_len: Sequence length used in all three terms.
        hidden_size: Hidden dimension used in all three terms.
        num_layers: Layer count applied to the first two terms.
        vocab_size: Vocabulary size used in the last term.
        global_batch_size: Batch size used in all three terms.
        global_world_size: Divisor applied together with the elapsed time;
            presumably the number of participating devices -- confirm.
        mlp_ratio: MLP expansion ratio; multiplied by 3 / 2 when
            ``use_swiglu`` is true.
        use_swiglu: Whether to scale ``mlp_ratio`` by 3 / 2 before use.

    Returns:
        ``flops_per_iteration / (elapsed_time_per_iter * global_world_size * 10**12)``.

    Raises:
        ZeroDivisionError: If ``elapsed_time_per_iter * global_world_size`` is zero.
    """

    # NOTE(review): 3 -> 4 presumably accounts for one extra forward pass
    # (recomputation) when activation checkpointing is on -- confirm.
    checkpoint_activations_factor = 4 if checkpoint else 3
    # NOTE(review): the extra 0.5 relative to the factor above presumably
    # models the recomputation FlashAttention does in its backward -- confirm.
    flashattn_activations_factor = 4.5 if checkpoint else 3.5

    if use_swiglu:
        # NOTE(review): presumably because a SwiGLU MLP has three projection
        # matrices instead of two -- confirm.
        mlp_ratio = mlp_ratio * 3 / 2

    # Term proportional to hidden_size**2 (per layer).
    dense_flops = (
        checkpoint_activations_factor
        * (8 + mlp_ratio * 4)
        * global_batch_size
        * seq_len
        * hidden_size**2
        * num_layers
    )
    # Term proportional to seq_len**2 (per layer).
    attention_flops = (
        4 * global_batch_size * seq_len**2 * hidden_size * num_layers * flashattn_activations_factor
    )
    # Term proportional to vocab_size (not multiplied by num_layers).
    vocab_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size

    flops_per_iteration = dense_flops + attention_flops + vocab_flops

    tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
    return tflops
|
||||||
|
|
||||||
|
|
||||||
class DummyProfile:
|
class DummyProfile:
|
||||||
"""
|
"""
|
||||||
Dummy Profile.
|
Dummy Profile.
|
||||||
|
|
Loading…
Reference in New Issue