[hotfix] pass test_complete_workflow (#1877)

pull/1880/head
Jiarui Fang 2022-11-10 17:53:39 +08:00 committed by GitHub
parent 986f8cbaa7
commit 51597f6a28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -50,7 +50,7 @@ def run_workflow(world_size, dev):
annotated_gm.recompile()
# materialization and sharding
ctx.lazy_init_parameters(annotated_gm)
ctx.lazy_init_parameters(annotated_gm, device=dev)
for param in model.parameters():
assert not param.is_meta
@ -84,4 +84,4 @@ def test_complete_workflow(world_size, dev):
if __name__ == '__main__':
test_complete_workflow(1)
test_complete_workflow(1, 'cuda')