|
|
|
@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
|
|
|
|
|
return splited_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exam_zero_1_2():
|
|
|
|
|
@parameterize("fp8_communication", [True, False])
|
|
|
|
|
def exam_zero_1_2(fp8_communication: bool):
|
|
|
|
|
"""
|
|
|
|
|
In this test, we want to test whether zero stage 1 and 2
|
|
|
|
|
deliver the same numerical results despite different communication
|
|
|
|
@ -73,10 +74,18 @@ def exam_zero_1_2():
|
|
|
|
|
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
|
|
|
|
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
|
|
|
|
|
zero1_optimizer = LowLevelZeroOptimizer(
|
|
|
|
|
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
|
|
|
|
|
zero1_optimizer,
|
|
|
|
|
overlap_communication=True,
|
|
|
|
|
initial_scale=128,
|
|
|
|
|
verbose=True,
|
|
|
|
|
fp8_communication=fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
zero2_optimizer = LowLevelZeroOptimizer(
|
|
|
|
|
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
|
|
|
|
|
zero2_optimizer,
|
|
|
|
|
overlap_communication=True,
|
|
|
|
|
partition_grad=True,
|
|
|
|
|
initial_scale=128,
|
|
|
|
|
fp8_communication=fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
# create data
|
|
|
|
|
seed_all(2001 + local_rank)
|
|
|
|
@ -97,7 +106,10 @@ def exam_zero_1_2():
|
|
|
|
|
if g1 is None or g2 is None:
|
|
|
|
|
assert g1 is None and g2 is None
|
|
|
|
|
continue
|
|
|
|
|
assert torch.allclose(g1, g2)
|
|
|
|
|
if fp8_communication:
|
|
|
|
|
loose_close(g1, g2, dtype=torch.float16)
|
|
|
|
|
else:
|
|
|
|
|
assert torch.allclose(g1, g2)
|
|
|
|
|
|
|
|
|
|
# step
|
|
|
|
|
zero1_optimizer.step()
|
|
|
|
@ -105,7 +117,8 @@ def exam_zero_1_2():
|
|
|
|
|
|
|
|
|
|
# check updated param
|
|
|
|
|
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
|
|
|
|
assert torch.allclose(z1p, z2p)
|
|
|
|
|
if not fp8_communication:
|
|
|
|
|
assert torch.allclose(z1p, z2p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
|
|
|
|