mirror of https://github.com/hpcaitech/ColossalAI
[fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
parent ce58d8e8bf
commit 8366a7855f
@@ -509,6 +509,15 @@ def run_fwd_bwd_iter_input(test_config):
             "precision": "bf16",
             "num_model_chunk": 2,
         },
+        # {
+        #     "batch_size": 8,
+        #     "tp_size": 1,
+        #     "pp_size": 4,
+        #     "num_microbatches": 8,
+        #     "zero_stage": 1,
+        #     "precision": "bf16",
+        #     "num_model_chunk": 2,
+        # },
     ],
 )
 def run_fwd_bwd_vschedule_with_optim(test_config):
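For context (an annotation, not part of the diff): the entries above are test configurations fed to the test function, and the newly added entry is commented out, so it does not yet run. A minimal sketch of the surrounding pattern, assuming ColossalAI's parameterize test decorator and hypothetical config values:

from colossalai.testing import parameterize

@parameterize(
    "test_config",
    [
        {"batch_size": 4, "tp_size": 1, "pp_size": 4, "num_microbatches": 4,
         "zero_stage": 1, "precision": "bf16", "num_model_chunk": 2},
    ],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
    ...  # the decorator calls the test once per config dict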
@@ -593,8 +602,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         local_chunk.append(sub_model)

     # init optimizer
-    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
-    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
+    optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))

     after_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
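Why momentum matters here (an illustrative sketch, not part of the diff): with a nonzero momentum, torch.optim.SGD lazily allocates a per-parameter "momentum_buffer" on the first step(), and that extra memory is what the tightened bound in the next hunk accounts for.

import torch

layer = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(layer.parameters(), lr=1e-5, momentum=0.1)
layer(torch.randn(2, 8)).sum().backward()
opt.step()  # momentum buffers are created here, one per parameter
for p in layer.parameters():
    assert opt.state[p]["momentum_buffer"].shape == p.shape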
@@ -617,15 +626,16 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     if rank != 0:
         # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
         # output hid_dim * hid_dim * 4(fp32) / 1024**3
-        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}")
-        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3)
+        # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
+        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
+        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
     else:
         # rank0 will also hold output;
         print(
-            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
+            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
         )
         assert round((after_pp_step_memory - after_init_memory), 5) <= round(
-            (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
+            (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
         )

     ##########################
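Worked arithmetic behind the new bound (a sketch, not part of the diff, assuming in_dim == hid_dim; the test defines the real value): per the comments above, each non-first stage holds 2 weight gradients, 1 output tensor, and now 2 SGD momentum buffers, each hid_dim x hid_dim in fp32, so the factor grows from 3 to 5. Rank 0 additionally keeps the batch output, hence the extra batch_size * in_dim * in_dim * 4 bytes in its bound.

in_dim = 4096          # hypothetical value for illustration
fp32 = 4               # bytes per element
w_grads = 2 * in_dim * in_dim * fp32            # 2 layers per stage
output = 1 * in_dim * in_dim * fp32             # one output tensor
momentum_buffers = 2 * in_dim * in_dim * fp32   # new with momentum=0.1
bound_gb = (w_grads + output + momentum_buffers) / 1024**3
assert bound_gb == in_dim * in_dim * fp32 * 5 / 1024**3  # the asserted bound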
@@ -681,10 +691,15 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     ##########################
     # assert optim state
     ##########################
-    optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0]
-    optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
+    optim_base_state = optimizer_base.state_dict()["state"]
+    optim_pp_state = optimizer_pp.state_dict()["state"]
+    optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
+    optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
+    # if rank == 0:
+    #     print(f"optim_base_state {optim_base_state}")

-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
+    # assert param group
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
         if key_base == key_pp:
             if key_base != "params":
                 assert val_base == val_pp
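The split the commit introduces mirrors the layout of torch.optim.Optimizer.state_dict(): "state" maps integer parameter indices to per-parameter tensors (e.g. SGD's "momentum_buffer"), while "param_groups" carries hyperparameters plus a "params" index list. A self-contained sketch of that layout (not part of the diff):

import torch

m = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(m.parameters(), lr=1e-5, momentum=0.1)
m(torch.randn(1, 4)).sum().backward()
opt.step()

sd = opt.state_dict()
assert set(sd.keys()) == {"state", "param_groups"}
assert "momentum_buffer" in sd["state"][0]        # per-param state, keyed by index
assert sd["param_groups"][0]["params"] == [0, 1]  # weight and bias indices
assert sd["param_groups"][0]["lr"] == 1e-5        # hyperparameters live here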
@@ -694,6 +709,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
                 # params pp: [0, 1];
                 assert val_base[:2] == val_pp

+    # assert state
+    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
+    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+

 # TODO:4) support Hybrid base 3)
 def run_with_hybridplugin(test_config):
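The two assert_close calls (presumably torch.testing.assert_close, imported elsewhere in the file) encode the stage-to-global parameter mapping: with 2 layers per pipeline stage, local state indices [0, 1] on rank r line up with global indices [2*r, 2*r + 1] in the unsharded baseline optimizer. A tiny sketch of that mapping, using a hypothetical helper not present in the test:

def local_to_global(rank: int, local_idx: int, layers_per_stage: int = 2) -> int:
    # local momentum-buffer index -> index into the baseline optimizer state
    return layers_per_stage * rank + local_idx

assert [local_to_global(0, i) for i in (0, 1)] == [0, 1]
assert [local_to_global(3, i) for i in (0, 1)] == [6, 7]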