[hotfix] fix autoparallel compatibility test issues (#2754)

YuliangLiu0306 2023-02-23 17:28:36 +08:00 committed by GitHub
parent 0f392d7403
commit 819e25d8b1
3 changed files with 23 additions and 9 deletions

@@ -330,6 +330,7 @@ def autoparallelize(model: nn.Module,
         device_mesh,
         solver_preference=solver_preference,
         dataloader_option=dataloader_option,
+        shard_option=shard_option,
         save_solver_solution=save_solver_solution,
         load_solver_solution=load_solver_solution,
         solution_path=solver_solution_path,
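The library-side fix is this one forwarded keyword: autoparallelize accepted shard_option but never passed it to the inner initialization call, so the caller's choice was silently dropped in favor of the default. A minimal sketch of the bug pattern, with hypothetical names standing in for the real ColossalAI internals:

    def _initialize(model, shard_option='standard'):
        # Hypothetical stand-in for the inner call in the diff above.
        print(f'solving with shard_option={shard_option}')

    def autoparallelize(model, shard_option='standard'):
        # Before the fix: shard_option was accepted but not forwarded,
        # so _initialize() always ran with its own default.
        _initialize(model, shard_option=shard_option)    # the added line

    autoparallelize(None, shard_option='full_shard')    # now prints 'full_shard'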

@@ -33,11 +33,15 @@ def check_compatibility_with_ddp(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).cuda()
-    input = torch.rand(4, 4).cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
+    input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
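The test change replaces random inputs with a deterministic global batch sharded across the data-parallel dimension: ranks 0 and 1 get rows 0-3 of the (8, 4) input, ranks 2 and 3 get rows 4-7. The reference gradient is divided by 2 because DDP averages gradients over the 2 data-parallel replicas, while the single-process reference pass backpropagates the whole 8-row batch, which sums the per-shard gradients. The identity can be checked in plain PyTorch, independent of ColossalAI:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(4, 4, bias=False)

    # Single-process reference pass over the full 8-row batch,
    # mirroring the *_compare tensors in the test.
    full = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)
    model(full).sum().backward()
    grad_full = model.weight.grad.clone()

    # Per-shard passes, one per data-parallel group of ranks.
    shard_grads = []
    for shard in (full[:4], full[4:]):
        model.weight.grad = None
        model(shard).sum().backward()
        shard_grads.append(model.weight.grad.clone())

    # DDP averages gradients across the 2 replicas; the full-batch gradient
    # is the *sum* of the shard gradients, hence grad_compare / 2.
    torch.testing.assert_close((shard_grads[0] + shard_grads[1]) / 2, grad_full / 2)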
@@ -70,7 +74,10 @@ def check_compatibility_with_ddp(rank, world_size, port):
     gm = DDP(gm, process_group=dp_process_group)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
     loss = output.sum()
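With sharded inputs, each rank's output is only its local 4-row slice, so it is compared against the matching slice of the full-batch reference. Tensor.narrow(dim, start, length) returns a view of length elements starting at start along dim:

    import torch

    output_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)
    top = output_compare.narrow(0, 0, 4)       # rows 0..3, expected on ranks 0/1
    bottom = output_compare.narrow(0, 4, 4)    # rows 4..7, expected on ranks 2/3
    assert torch.equal(torch.cat([top, bottom]), output_compare)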

@@ -37,12 +37,15 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).half().cuda()
-    input = torch.rand(4, 4).half().cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16).reshape(4, 4).half().cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32).reshape(4, 4).half().cuda()
+    input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
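The Gemini test gets the same rework, with one fp16-specific detail: the inputs come from torch.arange rather than torch.rand, and small integers are exactly representable in half precision (every integer up to 2048 is), so the full-batch and sharded forward passes stay exactly comparable. A quick check of that property:

    import torch

    x = torch.arange(0, 32).half()
    # Integers in [0, 2048] round-trip through fp16 with no rounding error.
    assert torch.equal(x.float(), torch.arange(0, 32, dtype=torch.float))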
@@ -79,7 +82,10 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
     optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
     loss = output.sum()
     optimizer.zero_grad()
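The assertion change mirrors the DDP test above. Since both tests repeat the same rank-to-slice mapping, it could be factored into a small helper; this is a hypothetical refactor, not part of the commit:

    from torch.testing import assert_close

    def assert_local_shard(output, output_compare, rank, num_shards=2, group_size=2):
        # Hypothetical helper: ranks [0, 1] own the first slice of the global
        # batch along dim 0, ranks [2, 3] the second, matching the branches above.
        rows = output_compare.size(0) // num_shards
        shard = rank // group_size
        assert_close(output, output_compare.narrow(0, shard * rows, rows))

Both tests' four-line branches would then collapse to assert_local_shard(output, output_compare, rank).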