
[Tensor] update ColoTensor torch_function (#822)

* Revert "[zero] add ZeroTensorShardStrategy (#793)"

This reverts commit 88759e289e.

* [gemini] set cpu memory capacity

* [log] local throughput collecting

* polish

* polish

* polish

* polish code

* polish

* polish code

* add a new tensor structure and override linear for it

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* [tensor] rename and reorganize directory structure.

* rm useless dir

* polish

* polish

* [tensor] handle functions that are not wrapped

* polish
Author: Jiarui Fang (committed by GitHub) · Commit: 68dcd51d41
colossalai/tensor/colo_tensor.py · 13 additions

@@ -24,6 +24,13 @@ class ColoTensor(object):
        for kwarg in kwargs.values():
            if isinstance(kwarg, ColoTensor):
                return _COLOSSAL_OPS[func](types, args, kwargs, None)
        raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
                           f"kwargs: {kwargs} not supported for ColoTensor!")
    else:
        # If we have not hijacked the function, convert the ColoTensors in args and kwargs to torch tensors.
        args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
        if kwargs is None:
            kwargs = {}
        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
        return func(*args, **kwargs)
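
For context, the pattern this commit adds is the standard `__torch_function__` fallback: if a torch function has no registered Colossal-AI implementation, every ColoTensor argument is unwrapped to a plain torch tensor and the original function is called. Below is a minimal, self-contained sketch of that pattern; `WrapperTensor`, `_HIJACKED_OPS`, and the demo at the bottom are illustrative names for this sketch only, not ColossalAI APIs, and it assumes a PyTorch version (roughly 1.7 or newer) where the `__torch_function__` protocol works for non-Tensor types.

    import torch

    # Ops with a custom (e.g. sharded) implementation would be registered here,
    # keyed by the original torch function.
    _HIJACKED_OPS = {}


    class WrapperTensor:
        """Wraps a plain torch.Tensor; stands in for ColoTensor in this sketch."""

        def __init__(self, tensor: torch.Tensor):
            self._tensor = tensor

        def torch_tensor(self) -> torch.Tensor:
            return self._tensor

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if func in _HIJACKED_OPS:
                # Dispatch to the registered custom implementation.
                return _HIJACKED_OPS[func](types, args, kwargs, None)
            # Fallback (the behaviour this commit adds to ColoTensor): unwrap every
            # WrapperTensor in args/kwargs and call the original torch function.
            args = [a.torch_tensor() if isinstance(a, WrapperTensor) else a for a in args]
            kwargs = {} if kwargs is None else {
                k: v.torch_tensor() if isinstance(v, WrapperTensor) else v
                for k, v in kwargs.items()
            }
            return func(*args, **kwargs)


    if __name__ == "__main__":
        w = WrapperTensor(torch.ones(2, 3))
        # torch.sum is not in _HIJACKED_OPS, so it falls through to the plain tensor.
        print(torch.sum(w))           # tensor(6.)
        print(torch.add(w, w).sum())  # tensor(12.)

With this fallback in place, torch functions that are not explicitly hijacked still run on the wrapped tensor instead of raising, at the cost of bypassing whatever distributed semantics the hijacked ops provide.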
