mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
973 B
41 lines
973 B
import pytest
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
import torch
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
|
|
@pytest.mark.skip(reason="FIXME(ver217): support lazy init")
def test_lazy_init():
    """Check that ``ColoInitContext(lazy_memory_allocate=True)`` defers allocation.

    A ``torch.nn.Linear`` is built inside the context. Its weight payload
    (``_torch_tensor``) is expected to be empty until ``torch_tensor()`` is
    called. After that call it should hold ``in_dim * out_dim`` elements.
    """
    in_dim = 4
    out_dim = 5

    with ColoInitContext(lazy_memory_allocate=True):
        fc = torch.nn.Linear(in_dim, out_dim, bias=True)

    # lazy_memory_allocate=True: no payload is maintained until first access.
    assert fc.weight._torch_tensor.numel() == 0

    # Requesting the tensor should materialize the full weight payload.
    fc.weight.torch_tensor()
    assert fc.weight._torch_tensor.numel() == in_dim * out_dim
|
|
|
|
|
|
@pytest.mark.skip(reason="FIXME: relies on lazy init (lazy_memory_allocate=True), not supported yet")
def test_device():
    """Check that a lazily-allocated parameter lands on the requested device.

    The context is given ``device=get_current_device()``. After the lazy
    weight is evaluated via ``torch_tensor()``, its ``.device`` is expected to
    equal that same device.
    """
    in_dim = 4
    out_dim = 5
    # Query the device once so the context and the assertion use the same value.
    device = get_current_device()

    with ColoInitContext(lazy_memory_allocate=True, device=device):
        fc = torch.nn.Linear(in_dim, out_dim, bias=True)

    # Evaluate the lazy parameter so its payload is actually allocated.
    fc.weight.torch_tensor()
    assert fc.weight.device == device
|
|
|
|
|
|
if __name__ == '__main__':
    # Direct invocation: the test functions are called as plain functions, so
    # pytest's skip markers are not honored here; each test runs in order.
    for _test_fn in (test_lazy_init, test_device):
        _test_fn()
|