mirror of https://github.com/hpcaitech/ColossalAI
[tensor] fix kwargs in colo_tensor torch_function (#825)
parent eb1b89908c
commit 1a9e2c2dff
@@ -63,6 +63,6 @@ class ColoTensor(object):
             kwargs = {}

         kwargs = {
-            kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
+            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
         }
         return func(*args, **kwargs)
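
The removed comprehension iterated `for kwarg in kwargs`, which yields the dict's *keys*, so `isinstance(kwarg, ColoTensor)` tested the keyword-name strings (always False) and the `else` branch then mapped each key to itself, clobbering every keyword value. A minimal sketch of the failure, with a hypothetical `FakeColo` wrapper standing in for the real ColoTensor:

import torch

# FakeColo is a hypothetical stand-in for ColoTensor, just enough to show the bug.
class FakeColo:
    def __init__(self, data):
        self._data = data

    def torch_tensor(self):
        return self._data

kwargs = {'input': FakeColo(torch.ones(2))}

# Old comprehension: iterates over keys, so isinstance() sees the string
# 'input', the condition never fires, and the else branch maps the key to itself.
old = {kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, FakeColo) else kwarg for kwarg in kwargs}
assert old == {'input': 'input'}  # every value replaced by its own key

# Fixed comprehension: tests and unwraps the *values*.
new = {k: v.torch_tensor() if isinstance(v, FakeColo) else v for k, v in kwargs.items()}
assert isinstance(new['input'], torch.Tensor)  # unwrapped, safe to pass to func
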
@@ -59,14 +59,13 @@ def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)
+    assert torch.sum(input=t) == torch.sum(input=t_ref)


 def test_lazy_init_tensor():
     lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor == None
     assert lazy_t.torch_tensor().numel() == 6


 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    test_no_wrap_op()
     # test_element_wise()
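
The added assertion is the regression test for the fix: `torch.sum(t)` delivers the tensor to `__torch_function__` through `args`, while `torch.sum(input=t)` delivers it through `kwargs`, so only the keyword form hit the broken comprehension. A minimal sketch of that dispatch path, using a hypothetical `Wrapper` class rather than the real ColoTensor:

import torch

# Wrapper is a hypothetical, minimal ColoTensor-like class: any object that
# defines __torch_function__ participates in torch API dispatch.
class Wrapper:
    def __init__(self, data):
        self._data = data

    def torch_tensor(self):
        return self._data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Unwrap both positional and keyword arguments before calling the real op.
        args = tuple(a.torch_tensor() if isinstance(a, Wrapper) else a for a in args)
        kwargs = {k: v.torch_tensor() if isinstance(v, Wrapper) else v for k, v in kwargs.items()}
        return func(*args, **kwargs)

w = Wrapper(torch.randn(3, 5))
assert torch.sum(w) == torch.sum(w.torch_tensor())              # tensor arrives via args
assert torch.sum(input=w) == torch.sum(input=w.torch_tensor())  # tensor arrives via kwargs
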