# Mirror of https://github.com/hpcaitech/ColossalAI
# You cannot select more than 25 topics.
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 67 lines, 1.6 KiB
from numpy import allclose
|
|
import torch
|
|
from colossalai.tensor import ColoTensor
|
|
from copy import deepcopy
|
|
|
|
|
|
def test_linear():
    """Check that nn.Linear with ColoTensor-wrapped parameters matches a plain nn.Linear.

    ``fc_ref`` is a deep copy of ``fc`` taken before ``fc``'s ``weight`` and
    ``bias`` Parameters are replaced by ``ColoTensor`` wrappers. Both layers run
    the same input forward and backward; the summed losses and the weight/bias
    gradients must agree.
    """
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # Wrap fc's *own* parameters. The previous version wrapped fc_ref's
    # parameters. If ColoTensor keeps a reference to the tensor it wraps
    # (as `torch_tensor()` suggests), both layers then shared the same
    # Parameters, and the gradient check compared a .grad with itself.
    sharded_weight = ColoTensor(fc.weight)
    sharded_bias = ColoTensor(fc.bias)

    # Replace the torch nn.Parameters with ColoTensors.
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    # Compare with a tolerance rather than exact float equality.
    assert torch.allclose(loss_ref, loss)
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
    assert allclose(fc_ref.bias.grad, fc.bias.torch_tensor().grad)
|
|
|
|
|
|
# Disabled: this test case failed.
|
|
# def test_uniform():
|
|
# t = ColoTensor(torch.zeros(3, 5))
|
|
# torch.nn.init.uniform_(t)
|
|
# print(t)
|
|
|
|
|
|
def test_element_wise():
    """Check that mean, GELU and ReLU agree on a ColoTensor and on the plain tensor it wraps a clone of."""
    reference = torch.randn(3, 5)
    colo = ColoTensor(reference.clone())

    assert torch.mean(colo) == torch.mean(reference)

    # Activations are compared with numpy's tolerance-based allclose.
    for activation in (torch.nn.functional.gelu, torch.nn.functional.relu):
        assert allclose(activation(colo), activation(reference))
|
|
|
|
|
|
# Test a torch function that is not wrapped by a ColoTensor-specific op implementation.
|
|
def test_no_wrap_op():
    """Check that torch.sum gives the same result on a ColoTensor as on the plain tensor."""
    plain = torch.randn(3, 5)
    wrapped = ColoTensor(plain.clone())

    actual = torch.sum(wrapped)
    expected = torch.sum(plain)
    assert actual == expected
|
|
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc entry point for running a single test without pytest;
    # other tests in this module can be invoked here the same way.
    test_no_wrap_op()
|