|
|
|
@ -558,8 +558,9 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|
|
|
|
batch_size = test_config["batch_size"] |
|
|
|
|
num_layers = 8 |
|
|
|
|
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" |
|
|
|
|
in_dim = out_dim = 16 |
|
|
|
|
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") |
|
|
|
|
in_dim = out_dim = 4096 |
|
|
|
|
before_init_memory = torch.cuda.memory_allocated() / 1024**3 |
|
|
|
|
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") |
|
|
|
|
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) |
|
|
|
|
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] |
|
|
|
|
|
|
|
|
@ -595,9 +596,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|
|
|
|
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) |
|
|
|
|
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) |
|
|
|
|
|
|
|
|
|
print( |
|
|
|
|
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" |
|
|
|
|
) |
|
|
|
|
after_init_memory = torch.cuda.memory_allocated() / 1024**3 |
|
|
|
|
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") |
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize() |
|
|
|
|
result = scheduler.forward_backward_step( |
|
|
|
@ -611,6 +611,19 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|
|
|
|
|
|
|
|
|
optimizer_pp.step() |
|
|
|
|
|
|
|
|
|
after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 |
|
|
|
|
|
|
|
|
|
# assert memory |
|
|
|
|
if rank != 0: |
|
|
|
|
# w.grad: hid_dim * hid_dim * 4 bytes (fp32) * 2 (2 layers in each stage) / 1024**3 |
|
|
|
|
# output: hid_dim * hid_dim * 4 bytes (fp32) / 1024**3 |
|
|
|
|
assert (after_pp_step_memory - after_init_memory) == (in_dim * in_dim * 4 * 3 / 1024**3) |
|
|
|
|
else: |
|
|
|
|
# TODO: |
|
|
|
|
# rank0 will also hold output |
|
|
|
|
assert round((after_pp_step_memory - after_init_memory), 5) == round( |
|
|
|
|
(in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 |
|
|
|
|
) |
|
|
|
|
########################## |
|
|
|
|
# Fwd bwd for base |
|
|
|
|
########################## |
|
|
|
@ -619,7 +632,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|
|
|
|
loss_base = criterion(output_base) |
|
|
|
|
loss_base.backward() |
|
|
|
|
optimizer_base.step() |
|
|
|
|
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") |
|
|
|
|
|
|
|
|
|
########################## |
|
|
|
|
# assert loss & output |
|
|
|
|