[fix] fix mem assertion

pull/6034/head
duanjunwen 3 months ago
parent 400e5e5b23
commit 35a7b636b3

@@ -2,8 +2,7 @@ from .albert import *
 from .bert import *
 from .blip2 import *
 from .bloom import *
-# from .chatglm2 import *
 from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *

@@ -611,20 +611,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     optimizer_pp.step()
     torch.cuda.memory_allocated() / 1024**3
     after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3
     # assert memory
     if rank != 0:
         # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
         # output hid_dim * hid_dim * 4(fp32) / 1024**3
-        # assert (after_pp_step_memory - after_init_memory) == (in_dim * in_dim * 4 * 3 / 1024**3)
-        pass
+        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} == {(in_dim * in_dim * 4 * 3 / 1024**3)}")
+        assert (after_pp_step_memory - after_init_memory) == (in_dim * in_dim * 4 * 3 / 1024**3)
+        # pass
     else:
         # rank0 will also hold output;
-        # assert round((after_pp_step_memory - after_init_memory), 5) == round(
-        #     (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
-        # )
-        pass
+        print(
+            f"rank {rank}: {(after_pp_step_memory - after_init_memory)} == {(in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3)}"
+        )
+        assert round((after_pp_step_memory - after_init_memory), 5) == round(
+            (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
+        )
+        # pass
     ##########################
     # Fwd bwd for base
     ##########################
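
For context, the re-enabled asserts encode a simple fp32 memory budget per pipeline stage: after optimizer_pp.step(), every rank other than 0 is expected to have grown by the two per-layer weight gradients plus one output, i.e. 3 * in_dim**2 fp32 values, and rank 0 additionally holds the batch output of batch_size * in_dim**2 values. A minimal sketch of that arithmetic, using a hypothetical helper and illustrative in_dim/batch_size values (none of these names come from the test itself):

    FP32_BYTES = 4
    GIB = 1024**3

    def expected_growth_gib(in_dim: int, batch_size: int, rank: int) -> float:
        """Expected memory growth in GiB after the optimizer step (hypothetical helper)."""
        # w.grad for the 2 layers per stage plus one output:
        # 3 matrices of in_dim x in_dim fp32 values.
        growth = 3 * in_dim * in_dim * FP32_BYTES / GIB
        if rank == 0:
            # rank 0 also keeps the schedule's output tensor,
            # presumably of shape (batch_size, in_dim, in_dim).
            growth += batch_size * in_dim * in_dim * FP32_BYTES / GIB
        return growth

    print(expected_growth_gib(in_dim=4096, batch_size=8, rank=1))  # 0.1875 GiB
    print(expected_growth_gib(in_dim=4096, batch_size=8, rank=0))  # 0.6875 GiB

Note that the rank-0 comparison is rounded to 5 decimal places, presumably to tolerate floating-point noise in the extra output term, while the other ranks compare exactly.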
