[test] add zero fp8 test case

pull/5961/head
ver217 4 months ago
parent ae486ce005
commit 91e596d017

@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
return splited_grad return splited_grad
def exam_zero_1_2(): @parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
""" """
In this test, we want to test whether zero stage 1 and 2 In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication deliver the same numerical results despite different communication
@ -73,10 +74,18 @@ def exam_zero_1_2():
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer( zero1_optimizer = LowLevelZeroOptimizer(
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True,
fp8_communication=fp8_communication,
) )
zero2_optimizer = LowLevelZeroOptimizer( zero2_optimizer = LowLevelZeroOptimizer(
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128,
fp8_communication=fp8_communication,
) )
# create data # create data
seed_all(2001 + local_rank) seed_all(2001 + local_rank)
@ -97,6 +106,9 @@ def exam_zero_1_2():
if g1 is None or g2 is None: if g1 is None or g2 is None:
assert g1 is None and g2 is None assert g1 is None and g2 is None
continue continue
if fp8_communication:
loose_close(g1, g2, dtype=torch.float16)
else:
assert torch.allclose(g1, g2) assert torch.allclose(g1, g2)
# step # step
@ -105,6 +117,7 @@ def exam_zero_1_2():
# check updated param # check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
if not fp8_communication:
assert torch.allclose(z1p, z2p) assert torch.allclose(z1p, z2p)

Loading…
Cancel
Save