[tensor] fix kwargs in colo_tensor torch_function (#825)

pull/717/merge
Ziyue Jiang 2022-04-21 16:47:35 +08:00 committed by GitHub
parent eb1b89908c
commit 1a9e2c2dff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 4 deletions

View File

@@ -63,6 +63,6 @@ class ColoTensor(object):
kwargs = {}
kwargs = {
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
return func(*args, **kwargs)

View File

@@ -59,14 +59,13 @@ def test_no_wrap_op():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.sum(t) == torch.sum(t_ref)
assert torch.sum(input=t) == torch.sum(input=t_ref)
def test_lazy_init_tensor():
    """A lazily-constructed ColoTensor holds no payload until torch_tensor() is called."""
    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
    # Use `is None`, not `== None`: identity is the intended check, and once the
    # attribute holds a torch.Tensor, `== None` would trigger an element-wise
    # comparison instead of a plain boolean.
    assert lazy_t._torch_tensor is None
    # Materializing the tensor must produce the 2*3 = 6 elements requested above.
    assert lazy_t.torch_tensor().numel() == 6
if __name__ == '__main__':
    # Run each test in order when this file is executed as a script.
    for case in (test_lazy_init_tensor, test_no_wrap_op):
        case()
    # test_element_wise()