@@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

    rank = dist.get_rank()

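    # Compare per-parameter gradients of the original model and the sharded model after backward.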
    # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
        try:
            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
            print(f"{rank=},passed grad: {n1}, {n2}")
        except Exception as e:
            print(f"{rank=},failed grad: {n1} {p1.grad[:2, :2]}, {n2} {p2.grad[:2, :2]}")
            raise e
        assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before the optimizer step.
    grads_to_check = {}
@@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

    # print(grads_to_check)
    check_all_grad_tensors(grads_to_check)

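    # Before the optimizer step, the parameters of both models should still agree within tolerance.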
    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
        try:
            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
            print(f"{rank=},passed param before step: {n1}, {n2}")
        except Exception:
            print(
                f"{rank=},failed param before step: {n1} {p1[:2, :2] if p1 is not None else None}, {n2} {p2[:2, :2] if p2 is not None else None}"
            )
        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)

    # run the optimizer step on both models
    org_optimizer.step()
    sharded_optimizer.step()

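    # After the step, verify that the parameters of both models were updated consistently.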
    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
        try:
            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
            print(f"{rank=},passed param after step: {n1}, {n2}")
        except Exception as e:
            print(f"{rank=},failed param after step: {n1} {p1}, {n2} {p2}")
            raise e
        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":