|
|
|
@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
|
|
|
|
|
return splited_grad |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exam_zero_1_2(): |
|
|
|
|
@parameterize("fp8_communication", [True, False]) |
|
|
|
|
def exam_zero_1_2(fp8_communication: bool): |
|
|
|
|
""" |
|
|
|
|
In this test, we want to test whether zero stage 1 and 2 |
|
|
|
|
deliver the same numerical results despite different communication |
|
|
|
@ -73,10 +74,18 @@ def exam_zero_1_2():
|
|
|
|
|
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) |
|
|
|
|
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) |
|
|
|
|
zero1_optimizer = LowLevelZeroOptimizer( |
|
|
|
|
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True |
|
|
|
|
zero1_optimizer, |
|
|
|
|
overlap_communication=True, |
|
|
|
|
initial_scale=128, |
|
|
|
|
verbose=True, |
|
|
|
|
fp8_communication=fp8_communication, |
|
|
|
|
) |
|
|
|
|
zero2_optimizer = LowLevelZeroOptimizer( |
|
|
|
|
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 |
|
|
|
|
zero2_optimizer, |
|
|
|
|
overlap_communication=True, |
|
|
|
|
partition_grad=True, |
|
|
|
|
initial_scale=128, |
|
|
|
|
fp8_communication=fp8_communication, |
|
|
|
|
) |
|
|
|
|
# create data |
|
|
|
|
seed_all(2001 + local_rank) |
|
|
|
@ -97,7 +106,10 @@ def exam_zero_1_2():
|
|
|
|
|
if g1 is None or g2 is None: |
|
|
|
|
assert g1 is None and g2 is None |
|
|
|
|
continue |
|
|
|
|
assert torch.allclose(g1, g2) |
|
|
|
|
if fp8_communication: |
|
|
|
|
loose_close(g1, g2, dtype=torch.float16) |
|
|
|
|
else: |
|
|
|
|
assert torch.allclose(g1, g2) |
|
|
|
|
|
|
|
|
|
# step |
|
|
|
|
zero1_optimizer.step() |
|
|
|
@ -105,7 +117,8 @@ def exam_zero_1_2():
|
|
|
|
|
|
|
|
|
|
# check updated param |
|
|
|
|
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): |
|
|
|
|
assert torch.allclose(z1p, z2p) |
|
|
|
|
if not fp8_communication: |
|
|
|
|
assert torch.allclose(z1p, z2p) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@parameterize("dtype", [torch.float16, torch.bfloat16]) |
|
|
|
|