@@ -3,6 +3,7 @@ import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy


def test_linear():
    in_dim = 4
    out_dim = 5
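The body of test_linear is not shown in this hunk. The deepcopy import suggests the usual baseline pattern: keep an untouched copy of a plain torch module, wrap the other copy's parameters as ColoTensor, and compare the two forward passes. A minimal sketch of that pattern, assuming a hypothetical nn.Linear stand-in and the ColoTensor.init_from_torch_tensor constructor that appears later in this file (none of the lines below are taken from the diff):

    import torch
    from copy import deepcopy
    from colossalai.tensor import ColoTensor

    in_dim, out_dim = 4, 5
    fc = torch.nn.Linear(in_dim, out_dim, bias=True)   # hypothetical module under test
    fc_ref = deepcopy(fc)                              # untouched torch baseline
    w = ColoTensor.init_from_torch_tensor(fc.weight)   # wrap one parameter as a ColoTensor
    x = torch.randn(1, in_dim)
    # The wrapped path should reproduce the baseline forward pass on the same input.
    assert torch.allclose(torch.nn.functional.linear(x, w.torch_tensor(), fc.bias), fc_ref(x))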
@@ -44,6 +45,7 @@ def test_linear():
# torch.nn.init.uniform_(t)
# print(t)


def test_element_wise():
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
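test_element_wise builds the ColoTensor from a clone of t_ref, so the same torch call can be run on both objects and the results compared. A minimal sketch of that comparison, assuming the wrapper can be passed to ordinary torch functions exactly like t_ref (torch.mean is only an illustrative op, not taken from the diff):

    import torch
    from colossalai.tensor import ColoTensor

    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
    # Identical data, so the same op on the wrapper and on the plain tensor must agree.
    assert torch.mean(t) == torch.mean(t_ref)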
@@ -59,10 +61,12 @@ def test_no_wrap_op():
    assert torch.sum(t) == torch.sum(t_ref)
    assert torch.sum(input=t) == torch.sum(input=t_ref)


def test_lazy_init_tensor():
    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
    assert lazy_t._torch_tensor.numel() == 0
    assert lazy_t.torch_tensor().numel() == 6
    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()


def check_all():
    test_linear()
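The lazy-init assertions read as follows: constructing ColoTensor(2, 3, ...) records the logical shape but defers allocation, so the private _torch_tensor payload is still empty, and the first torch_tensor() call materializes a 2x3 float32 buffer. The same checks, annotated with that expected reading (the exact point of materialization is an assumption drawn from this test, not from the ColoTensor implementation):

    import torch
    from colossalai.tensor import ColoTensor

    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
    assert lazy_t._torch_tensor.numel() == 0    # nothing allocated at construction time
    assert lazy_t.torch_tensor().numel() == 6   # a 2x3 payload exists once it is requested
    assert lazy_t.numel() == 6                  # the wrapper reports the logical size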
@@ -70,5 +74,6 @@ def check_all():
    test_no_wrap_op()
    test_lazy_init_tensor()


if __name__ == '__main__':
    check_all()
    test_lazy_init_tensor()
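Run note (not part of the diff): because of the __main__ guard, the file can be executed directly to run check_all(); under pytest the individual test_* functions are collected automatically, e.g. pytest path/to/this_test_file.py -q (path shown as a placeholder).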