[fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

pull/6034/head
duanjunwen 2024-09-09 09:27:13 +00:00
parent ce58d8e8bf
commit 8366a7855f
1 changed files with 28 additions and 9 deletions

View File

@ -509,6 +509,15 @@ def run_fwd_bwd_iter_input(test_config):
"precision": "bf16",
"num_model_chunk": 2,
},
# {
# "batch_size": 8,
# "tp_size": 1,
# "pp_size": 4,
# "num_microbatches": 8,
# "zero_stage": 1,
# "precision": "bf16",
# "num_model_chunk": 2,
# },
],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
@ -593,8 +602,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
local_chunk.append(sub_model)
# init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
after_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
@ -617,15 +626,16 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
if rank != 0:
# w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
# output hid_dim * hid_dim * 4(fp32) / 1024**3
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}")
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3)
# optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
else:
# rank0 will also hold output;
print(
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
)
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
(in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
(in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
)
##########################
@ -681,10 +691,15 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
##########################
# assert optim state
##########################
optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0]
optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
optim_base_state = optimizer_base.state_dict()["state"]
optim_pp_state = optimizer_pp.state_dict()["state"]
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
# if rank == 0:
# print(f"optim_base_state {optim_base_state}")
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
# assert param group
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp:
if key_base != "params":
assert val_base == val_pp
@ -694,6 +709,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# params pp: [0, 1];
assert val_base[:2] == val_pp
# assert state
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
# TODO:4) support Hybrid base 3)
def run_with_hybridplugin(test_config):