From 68dcd51d41aa7871cec13e6ea4c489c8c299cbe1 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 21 Apr 2022 14:25:27 +0800 Subject: [PATCH] [Tensor] update ColoTensor torch_function (#822) * Revert "[zero] add ZeroTensorShardStrategy (#793)" This reverts commit 88759e289efd0a7b5e0d7bf8e01dbe29db85cf71. * [gemini] set cpu memory capacity * [log] local throughput collecting * polish * polish * polish * polish code * polish * polish code * add a new tensor structure and override linear for it * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * [tensor] renaming and reorganize directory structure. * rm useless dir * polish * polish * [tensor] handle the function not wrapped * polish --- colossalai/tensor/colo_tensor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 47e693720..6f82f2c07 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -24,6 +24,13 @@ class ColoTensor(object): for kwarg in kwargs.values(): if isinstance(kwarg, ColoTensor): return _COLOSSAL_OPS[func](types, args, kwargs, None) + else: + # If we have not hijacked the function, convert the ColoTensors in args and kwargs to torch tensors. + args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args] + if kwargs is None: + kwargs = {} - raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ColoTensor!") + kwargs = { + kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs + } + return func(*args, **kwargs)