|
|
@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
|
|
|
|
assert torch.equal(zero1_output, zero2_output)
|
|
|
|
assert torch.equal(zero1_output, zero2_output)
|
|
|
|
|
|
|
|
|
|
|
|
# zero-dp backward
|
|
|
|
# zero-dp backward
|
|
|
|
no_sync = number == 0
|
|
|
|
zero1_optimizer.backward(zero1_output.sum().float())
|
|
|
|
with conditional_context(zero1_optimizer.no_sync(), no_sync):
|
|
|
|
zero2_optimizer.backward(zero2_output.sum().float())
|
|
|
|
zero1_optimizer.backward(zero1_output.sum().float())
|
|
|
|
|
|
|
|
with conditional_context(zero2_optimizer.no_sync(), no_sync):
|
|
|
|
|
|
|
|
zero2_optimizer.backward(zero2_output.sum().float())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if check_flag:
|
|
|
|
|
|
|
|
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
|
|
|
|
|
|
|
if z2p.grad is not None:
|
|
|
|
|
|
|
|
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
|
|
|
|
|
|
|
assert torch.equal(z1p.grad, z2p.grad)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fwd_bwd_func(0, input_data1, True)
|
|
|
|
fwd_bwd_func(0, input_data1, True)
|
|
|
|
fwd_bwd_func(1, input_data2, False)
|
|
|
|
fwd_bwd_func(1, input_data2, False)
|
|
|
@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
|
|
|
|
assert torch.equal(z1p.data, z2p.data)
|
|
|
|
assert torch.equal(z1p.data, z2p.data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exam_zero_1_grad_acc():
|
|
|
|
def exam_zero_1_grad_acc(sync):
|
|
|
|
local_rank = torch.distributed.get_rank()
|
|
|
|
local_rank = torch.distributed.get_rank()
|
|
|
|
seed_all(2008)
|
|
|
|
seed_all(2008)
|
|
|
|
|
|
|
|
|
|
|
@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
|
|
|
|
input_data1 = torch.randn(32, 128).cuda()
|
|
|
|
input_data1 = torch.randn(32, 128).cuda()
|
|
|
|
input_data2 = torch.randn(32, 128).cuda()
|
|
|
|
input_data2 = torch.randn(32, 128).cuda()
|
|
|
|
|
|
|
|
|
|
|
|
def fwd_bwd_func(number, cur_data, check_flag):
|
|
|
|
def fwd_bwd_func(no_sync, cur_data, check_flag):
|
|
|
|
|
|
|
|
|
|
|
|
no_sync = number == 0
|
|
|
|
|
|
|
|
# zero1 fwd and bwd
|
|
|
|
# zero1 fwd and bwd
|
|
|
|
with conditional_context(zero_optimizer.no_sync(), no_sync):
|
|
|
|
with conditional_context(zero_optimizer.no_sync(), no_sync):
|
|
|
|
zero_output = zero_model(cur_data)
|
|
|
|
zero_output = zero_model(cur_data)
|
|
|
@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
|
|
|
|
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
|
|
|
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
|
|
|
assert torch.equal(p.grad, z1p.grad)
|
|
|
|
assert torch.equal(p.grad, z1p.grad)
|
|
|
|
|
|
|
|
|
|
|
|
fwd_bwd_func(0, input_data1, True)
|
|
|
|
fwd_bwd_func(sync, input_data1, sync)
|
|
|
|
fwd_bwd_func(1, input_data2, False)
|
|
|
|
fwd_bwd_func(False, input_data2, False)
|
|
|
|
|
|
|
|
|
|
|
|
zero_optimizer.step()
|
|
|
|
zero_optimizer.step()
|
|
|
|
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
|
|
|
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
|
|
@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
|
|
|
|
def run_dist(rank, world_size, port):
|
|
|
|
def run_dist(rank, world_size, port):
|
|
|
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
|
|
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
|
|
|
|
|
|
|
|
|
|
|
exam_zero_1_grad_acc()
|
|
|
|
exam_zero_1_grad_acc(sync=True)
|
|
|
|
# gradient accumulation is not compatible with ZeRO-2
|
|
|
|
exam_zero_1_grad_acc(sync=False)
|
|
|
|
# exam_zero_1_2_grad_acc()
|
|
|
|
exam_zero_1_2_grad_acc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
@pytest.mark.dist
|
|
|
|