From 1a9e2c2dffb09a808dcee11b2532b7cff358b450 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Thu, 21 Apr 2022 16:47:35 +0800
Subject: [PATCH] [tensor] fix kwargs in colo_tensor torch_funtion (#825)

---
 colossalai/tensor/colo_tensor.py | 2 +-
 tests/test_tensor/test_op.py     | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index cfaac0331..6ed82aea9 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -63,6 +63,6 @@ class ColoTensor(object):
             kwargs = {}
 
         kwargs = {
-            kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
+            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
         }
         return func(*args, **kwargs)
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 3d1719eae..c45dca8da 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -59,14 +59,13 @@ def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)
-
+    assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 def test_lazy_init_tensor():
     lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor == None
     assert lazy_t.torch_tensor().numel() == 6
 
-
 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    test_no_wrap_op()
     # test_element_wise()