mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] update ColoTensor torch_function (#822)
* Revert "[zero] add ZeroTensorShardStrategy (#793)"
This reverts commit 88759e289e.
* [gemini] set cpu memory capacity
* [log] local throughput collecting
* polish
* polish
* polish
* polish code
* polish
* polish code
* add a new tensor structure and override linear for it
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* [tensor] renaming and reorganize directory structure.
* rm useless dir
* polish
* polish
* [tensor] handle the function not wrapped
* polish
pull/824/head
parent
660d2d1f1b
commit
68dcd51d41
|
@ -24,6 +24,13 @@ class ColoTensor(object):
|
|||
for kwarg in kwargs.values():
|
||||
if isinstance(kwarg, ColoTensor):
|
||||
return _COLOSSAL_OPS[func](types, args, kwargs, None)
|
||||
else:
|
||||
# If we have not hijacked the function, convert the ColoTensors in args and kwargs to torch tensors.
|
||||
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
|
||||
f"kwargs: {kwargs} not supported for ColoTensor!")
|
||||
kwargs = {
|
||||
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
|
||||
}
|
||||
return func(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue