mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] pass test_complete_workflow (#1877)
parent
986f8cbaa7
commit
51597f6a28
|
@ -50,7 +50,7 @@ def run_workflow(world_size, dev):
|
|||
annotated_gm.recompile()
|
||||
|
||||
# materialization and sharding
|
||||
ctx.lazy_init_parameters(annotated_gm)
|
||||
ctx.lazy_init_parameters(annotated_gm, device=dev)
|
||||
for param in model.parameters():
|
||||
assert not param.is_meta
|
||||
|
||||
|
@ -84,4 +84,4 @@ def test_complete_workflow(world_size, dev):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_complete_workflow(1)
|
||||
test_complete_workflow(1, 'cuda')
|
||||
|
|
Loading…
Reference in New Issue