mirror of https://github.com/hpcaitech/ColossalAI
[tensor] fix kwargs in colo_tensor torch_funtion (#825)
parent
eb1b89908c
commit
1a9e2c2dff
|
@ -63,6 +63,6 @@ class ColoTensor(object):
|
|||
kwargs = {}
|
||||
|
||||
kwargs = {
|
||||
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
|
||||
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
|
||||
}
|
||||
return func(*args, **kwargs)
|
||||
|
|
|
@ -59,14 +59,13 @@ def test_no_wrap_op():
|
|||
t_ref = torch.randn(3, 5)
|
||||
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
|
||||
assert torch.sum(t) == torch.sum(t_ref)
|
||||
|
||||
assert torch.sum(input=t) == torch.sum(input=t_ref)
|
||||
|
||||
def test_lazy_init_tensor():
|
||||
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
|
||||
assert lazy_t._torch_tensor == None
|
||||
assert lazy_t.torch_tensor().numel() == 6
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lazy_init_tensor()
|
||||
test_no_wrap_op()
|
||||
# test_element_wise()
|
||||
|
|
Loading…
Reference in New Issue