mirror of https://github.com/hpcaitech/ColossalAI
24 lines
586 B
Python
24 lines
586 B
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from colossalai.utils.model.lazy_init_context import LazyInitContext
|
||
|
|
||
|
def test_lazy_init_ctx():
    """Check that LazyInitContext defers parameter materialization.

    A module built inside the context should expose meta-tensor parameters.
    The in-place ``zero_()`` issued inside the context must be reflected in
    the real weights once ``lazy_init_parameters`` materializes them.
    """
    with LazyInitContext() as lazy_ctx:
        linear = nn.Linear(10, 10)
        linear.weight.zero_()

    # Before materialization the weight must still live on the meta device.
    assert linear.weight.is_meta

    # Materialize the parameters recorded by the context.
    lazy_ctx.lazy_init_parameters(linear)

    # After materialization the weight must be a real tensor...
    assert not linear.weight.is_meta
    # ...and must reflect the zero_() call recorded inside the context.
    assert torch.all(linear.weight == 0)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Allow running this check directly as a script, without pytest.
    test_lazy_init_ctx()
|