mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] pass test_complete_workflow (#1877)
parent
986f8cbaa7
commit
51597f6a28
|
@ -50,7 +50,7 @@ def run_workflow(world_size, dev):
|
||||||
annotated_gm.recompile()
|
annotated_gm.recompile()
|
||||||
|
|
||||||
# materialization and sharding
|
# materialization and sharding
|
||||||
ctx.lazy_init_parameters(annotated_gm)
|
ctx.lazy_init_parameters(annotated_gm, device=dev)
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
assert not param.is_meta
|
assert not param.is_meta
|
||||||
|
|
||||||
|
@ -84,4 +84,4 @@ def test_complete_workflow(world_size, dev):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_complete_workflow(1)
|
test_complete_workflow(1, 'cuda')
|
||||||
|
|
Loading…
Reference in New Issue