|
|
@ -3,7 +3,6 @@ import torch
|
|
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
from copy import deepcopy
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_linear():
|
|
|
|
def test_linear():
|
|
|
|
in_dim = 4
|
|
|
|
in_dim = 4
|
|
|
|
out_dim = 5
|
|
|
|
out_dim = 5
|
|
|
@ -45,7 +44,6 @@ def test_linear():
|
|
|
|
# torch.nn.init.uniform_(t)
|
|
|
|
# torch.nn.init.uniform_(t)
|
|
|
|
# print(t)
|
|
|
|
# print(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_element_wise():
|
|
|
|
def test_element_wise():
|
|
|
|
t_ref = torch.randn(3, 5)
|
|
|
|
t_ref = torch.randn(3, 5)
|
|
|
|
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
|
|
|
|
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
|
|
|
@ -66,6 +64,11 @@ def test_lazy_init_tensor():
|
|
|
|
assert lazy_t._torch_tensor == None
|
|
|
|
assert lazy_t._torch_tensor == None
|
|
|
|
assert lazy_t.torch_tensor().numel() == 6
|
|
|
|
assert lazy_t.torch_tensor().numel() == 6
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
def check_all():
|
|
|
|
|
|
|
|
test_linear()
|
|
|
|
|
|
|
|
test_element_wise()
|
|
|
|
test_no_wrap_op()
|
|
|
|
test_no_wrap_op()
|
|
|
|
# test_element_wise()
|
|
|
|
test_lazy_init_tensor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
|
|
check_all()
|
|
|
|