diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 17c30ad34..9f036220f 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
 from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
+from functools import lru_cache
 
 from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor import distspec, ProcessGroup
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional
+from typing import Optional, Set, Callable
+
+
+@lru_cache(None)
+def _get_my_nowrap_functions() -> Set[Callable]:
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+    }
 
 
 def _convert_output(output, pg: ProcessGroup):
@@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):
 
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            if func in get_default_nowrap_functions():
+            if func in _get_my_nowrap_functions():
                 return ret
             else:
                 pg = _scan_for_pg_from_args(args, kwargs)
@@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
         Args:
             dist_spec (_DistSpec): the target dist. spec.
         """
+        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
+            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
@@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
        an inline member function, converting dist spec of the tensor to REPLICATE
        """
-        self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
-        self.dist_spec = distspec.replicate()
+        self._convert_to_dist_spec(dist_spec=distspec.replicate())
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
@@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        # TODO(jiaruifang) check why this not work
-        # self.data = self.to_replicate()
-        self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
-                                                      self.process_group)
-        self.dist_spec = distspec.replicate()
-        return super().view(*args)
+        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        return replicated_t.view(*args)
 
     def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index 3c763562f..4bf035d26 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
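
Not part of the diff above: a minimal, self-contained sketch of the "nowrap" pattern the patch applies to `ColoTensor`. The class `WrappedTensor` and helper `_nowrap_functions` are hypothetical names used only for illustration; the point is that results of exempted accessor getters such as `Tensor.data.__get__` bypass re-wrapping, so `.data` comes back as a plain `torch.Tensor` rather than the subclass, which appears to be why the patch swaps `torch.overrides.get_default_nowrap_functions()` for its own `lru_cache`'d set that additionally exempts `Tensor.data.__get__`.

```python
from functools import lru_cache
from typing import Callable, Set

import torch


@lru_cache(None)
def _nowrap_functions() -> Set[Callable]:
    # Accessor getters whose results should NOT be re-wrapped into the subclass.
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
        Tensor.data.__get__,
    }


class WrappedTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in _nowrap_functions():
                return ret    # e.g. `.data` is handed back as a plain torch.Tensor
            if isinstance(ret, torch.Tensor) and type(ret) is not cls:
                ret = ret.as_subclass(cls)    # ordinary op results are re-wrapped
            return ret


t = torch.randn(4, 5).as_subclass(WrappedTensor)
print(type(t + 1).__name__)     # expected: WrappedTensor
print(type(t.data).__name__)    # expected: Tensor
```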