mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix autoparallel compatibility test issues (#2754)
parent 0f392d7403
commit 819e25d8b1
@@ -330,6 +330,7 @@ def autoparallelize(model: nn.Module,
                        device_mesh,
                        solver_preference=solver_preference,
                        dataloader_option=dataloader_option,
+                       shard_option=shard_option,
                        save_solver_solution=save_solver_solution,
                        load_solver_solution=load_solver_solution,
                        solution_path=solver_solution_path,
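Note on the first hunk: the fix forwards the shard_option argument into the inner solver call, so the value handed to autoparallelize is presumably no longer dropped in favour of the inner default. A minimal, purely illustrative sketch of that bug pattern, using hypothetical names (outer/inner) and placeholder option values rather than ColossalAI's real call chain:

    # Hypothetical sketch of the bug class fixed above: a wrapper accepts an
    # option but forgets to forward it, so the inner default always wins.
    def inner(model, shard_option='standard'):
        return f'solver ran with shard_option={shard_option}'

    def outer(model, shard_option='standard'):
        # Before the fix the kwarg was not forwarded; after the fix it is.
        return inner(model, shard_option=shard_option)

    print(outer(object(), shard_option='shard'))  # solver ran with shard_option=shard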
@@ -33,11 +33,15 @@ def check_compatibility_with_ddp(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).cuda()
-    input = torch.rand(4, 4).cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
+    input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
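Why the reference gradient is now halved: the 8-sample batch is split into two 4-sample shards across the data-parallel dimension, and DDP averages gradients over the two ranks, so for a sum loss the all-reduced gradient equals the full-batch gradient divided by 2. A single-process sketch of that identity (the toy linear layer and shapes are illustrative, not the test's MLP):

    import copy
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    layer = nn.Linear(4, 4, bias=False)
    full_batch = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)

    # Gradient of a sum loss over the full batch.
    layer(full_batch).sum().backward()
    grad_full = copy.deepcopy(layer.weight.grad)
    layer.weight.grad = None

    # Per-shard gradients, then the average that DDP would all-reduce to.
    shard_grads = []
    for shard in (full_batch[:4], full_batch[4:]):
        layer(shard).sum().backward()
        shard_grads.append(copy.deepcopy(layer.weight.grad))
        layer.weight.grad = None

    ddp_averaged = (shard_grads[0] + shard_grads[1]) / 2
    torch.testing.assert_close(ddp_averaged, grad_full / 2)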
@@ -70,7 +74,10 @@ def check_compatibility_with_ddp(rank, world_size, port):
     gm = DDP(gm, process_group=dp_process_group)
     output = gm(input)
 
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
     loss = output.sum()
 
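Each rank now compares its output against its own 4-row slice of the full-batch reference: Tensor.narrow(0, start, length) returns a view of rows [start, start + length) along dim 0. A standalone illustration of the slicing used above:

    import torch

    output_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)
    first_half = output_compare.narrow(0, 0, 4)   # rows 0-3, checked on ranks 0 and 1
    second_half = output_compare.narrow(0, 4, 4)  # rows 4-7, checked on ranks 2 and 3
    assert torch.equal(torch.cat([first_half, second_half]), output_compare)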
@@ -37,12 +37,15 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).half().cuda()
-
-    input = torch.rand(4, 4).half().cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16).reshape(4, 4).half().cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32).reshape(4, 4).half().cuda()
+    input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
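Side note on switching the fp16 test from torch.rand to torch.arange: every integer up to 2048 is exactly representable in float16, so the 0-31 inputs survive the .half() cast unchanged and the test data is deterministic across processes (torch.rand would differ per process unless the RNG were synchronized). A quick CPU-only check of that assumption:

    import torch

    x = torch.arange(0, 32).reshape(8, 4).half()  # int64 values cast to fp16
    assert torch.equal(x.double(), torch.arange(0, 32, dtype=torch.double).reshape(8, 4))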
@@ -79,7 +82,10 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
     optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank{rank} is correct')
     loss = output.sum()
     optimizer.zero_grad()