diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py index bb1a66812..a21a351f8 100644 --- a/tests/test_fx/test_complete_workflow.py +++ b/tests/test_fx/test_complete_workflow.py @@ -50,7 +50,7 @@ def run_workflow(world_size, dev): annotated_gm.recompile() # materialization and sharding - ctx.lazy_init_parameters(annotated_gm) + ctx.lazy_init_parameters(annotated_gm, device=dev) for param in model.parameters(): assert not param.is_meta @@ -84,4 +84,4 @@ def test_complete_workflow(world_size, dev): if __name__ == '__main__': - test_complete_workflow(1) + test_complete_workflow(1, 'cuda')