From 8754fa255376055c01aab4a3fab385454b8b7930 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 18:25:47 +0800 Subject: [PATCH] change threshold --- chunk_codegen_run.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 88c734903..99700e1af 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -45,8 +45,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output" - assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output" + + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output" # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()