mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix autoparallel compatibility test issues (#2754)
parent
0f392d7403
commit
819e25d8b1
|
@ -330,6 +330,7 @@ def autoparallelize(model: nn.Module,
|
|||
device_mesh,
|
||||
solver_preference=solver_preference,
|
||||
dataloader_option=dataloader_option,
|
||||
shard_option=shard_option,
|
||||
save_solver_solution=save_solver_solution,
|
||||
load_solver_solution=load_solver_solution,
|
||||
solution_path=solver_solution_path,
|
||||
|
|
|
@ -33,11 +33,15 @@ def check_compatibility_with_ddp(rank, world_size, port):
|
|||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = MLP(4).cuda()
|
||||
input = torch.rand(4, 4).cuda()
|
||||
output_compare = model(input)
|
||||
if rank in [0, 1]:
|
||||
input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
|
||||
elif rank in [2, 3]:
|
||||
input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
|
||||
input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
|
||||
output_compare = model(input_compare)
|
||||
loss_compare = output_compare.sum()
|
||||
loss_compare.backward()
|
||||
grad_compare = copy.deepcopy(model.linear_1.weight.grad)
|
||||
grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
|
@ -70,7 +74,10 @@ def check_compatibility_with_ddp(rank, world_size, port):
|
|||
gm = DDP(gm, process_group=dp_process_group)
|
||||
output = gm(input)
|
||||
|
||||
assert_close(output, output_compare)
|
||||
if rank in (0, 1):
|
||||
assert_close(output, output_compare.narrow(0, 0, 4))
|
||||
else:
|
||||
assert_close(output, output_compare.narrow(0, 4, 4))
|
||||
print(f'output on rank{rank} is correct')
|
||||
loss = output.sum()
|
||||
|
||||
|
|
|
@ -37,12 +37,15 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
|
|||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = MLP(4).half().cuda()
|
||||
|
||||
input = torch.rand(4, 4).half().cuda()
|
||||
output_compare = model(input)
|
||||
if rank in [0, 1]:
|
||||
input = torch.arange(0, 16).reshape(4, 4).half().cuda()
|
||||
elif rank in [2, 3]:
|
||||
input = torch.arange(16, 32).reshape(4, 4).half().cuda()
|
||||
input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
|
||||
output_compare = model(input_compare)
|
||||
loss_compare = output_compare.sum()
|
||||
loss_compare.backward()
|
||||
grad_compare = copy.deepcopy(model.linear_1.weight.grad)
|
||||
grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
|
@ -79,7 +82,10 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
|
|||
optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
|
||||
optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
|
||||
output = gm(input)
|
||||
assert_close(output, output_compare)
|
||||
if rank in (0, 1):
|
||||
assert_close(output, output_compare.narrow(0, 0, 4))
|
||||
else:
|
||||
assert_close(output, output_compare.narrow(0, 4, 4))
|
||||
print(f'output on rank{rank} is correct')
|
||||
loss = output.sum()
|
||||
optimizer.zero_grad()
|
||||
|
|
Loading…
Reference in New Issue