[tensor] redirect .data.__get__ to a tensor instance (#1239)

pull/1243/head
HELSON 2022-07-11 11:41:29 +08:00 committed by GitHub
parent 20da6e48c8
commit f6add9b720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 12 deletions

View File

@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy
import torch
from torch.overrides import get_default_nowrap_functions
from functools import lru_cache
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional
from typing import Optional, Set, Callable
@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
Tensor = torch.Tensor
return {
Tensor._base.__get__,
Tensor.grad.__get__,
Tensor._grad.__get__,
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
}
def _convert_output(output, pg: ProcessGroup):
@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
if func in _get_my_nowrap_functions():
return ret
else:
pg = _scan_for_pg_from_args(args, kwargs)
@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
Args:
dist_spec (_DistSpec): the target dist. spec.
"""
assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
self.dist_spec = distspec.replicate()
self._convert_to_dist_spec(dist_spec=distspec.replicate())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
"""
if self.is_replicate():
return super().view(*args)
# TODO(jiaruifang) check why this not work
# self.data = self.to_replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
self.process_group)
self.dist_spec = distspec.replicate()
return super().view(*args)
replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
"""override the torch buildin size()

View File

@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_dist_spec(distspec.replicate())
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"