|
|
|
@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|
|
|
|
rank = dist.get_rank()
|
|
|
|
|
# for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
|
|
|
|
|
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
|
|
|
|
try:
|
|
|
|
|
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
print(f"{rank=},passed grad: {n1}, {n2}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
|
|
|
|
|
raise e
|
|
|
|
|
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
|
|
|
|
|
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
|
|
|
|
grads_to_check = {}
|
|
|
|
@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|
|
|
|
# print(grads_to_check)
|
|
|
|
|
check_all_grad_tensors(grads_to_check)
|
|
|
|
|
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
|
|
|
|
try:
|
|
|
|
|
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
print(f"{rank=},passed param before step: {n1}, {n2}")
|
|
|
|
|
except Exception:
|
|
|
|
|
print(
|
|
|
|
|
f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
|
|
|
|
|
)
|
|
|
|
|
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
|
|
|
|
|
# optimizer executes step
|
|
|
|
|
org_optimizer.step()
|
|
|
|
|
sharded_optimizer.step()
|
|
|
|
|
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
|
|
|
|
try:
|
|
|
|
|
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
print(f"{rank=},passed param after step: {n1}, {n2}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(
|
|
|
|
|
f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
|
|
|
|
|
)
|
|
|
|
|
raise e
|
|
|
|
|
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
|
|
|
|
|
|
|
|
|
# check weights
|
|
|
|
|
if stage_manager is None or stage_manager.is_first_stage():
|
|
|
|
|
if test_config["precision"] == "fp32":
|
|
|
|
|