[tensor] redirect .data.__get__ to a tensor instance (#1239)

pull/1243/head
HELSON 2022-07-11 11:41:29 +08:00 committed by GitHub
parent 20da6e48c8
commit f6add9b720
2 changed files with 20 additions and 12 deletions
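What the change does in practice: Tensor.data.__get__ joins the set of functions whose results __torch_function__ returns unwrapped, so reading .data from a ColoTensor yields a plain torch.Tensor instead of another ColoTensor, while ordinary ops keep returning wrapped outputs. A minimal sketch of that behavior, not taken from the commit; it assumes torch.distributed is initialized so a default ProcessGroup can be constructed:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup()                                  # assumption: default group inside a launched job
t = ColoTensor.from_torch_tensor(torch.ones(4, 5), ColoTensorSpec(pg))

raw = t.data                                         # with this commit: a torch.Tensor, not a ColoTensor
assert not isinstance(raw, ColoTensor)
out = t + 1                                          # ordinary ops are still wrapped on the way out
assert isinstance(out, ColoTensor)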


@@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
 from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
+from functools import lru_cache
 from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor import distspec, ProcessGroup
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional
+from typing import Optional, Set, Callable
+
+
+@lru_cache(None)
+def _get_my_nowrap_functions() -> Set[Callable]:
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+    }
 
 
 def _convert_output(output, pg: ProcessGroup):
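A small sanity check of the new helper, assuming it stays a module-level function in colossalai.tensor.colo_tensor (an assumed path): thanks to lru_cache(None) the set is built once and the same object is returned on every later call, and the bound descriptor getters compare equal across accesses, so the membership test used in __torch_function__ stays cheap.

import torch
from colossalai.tensor.colo_tensor import _get_my_nowrap_functions   # assumed module path

fns = _get_my_nowrap_functions()
assert torch.Tensor.data.__get__ in fns       # the getter this commit adds to the no-wrap set
assert _get_my_nowrap_functions() is fns      # lru_cache(None): computed once, reused afterwards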
@@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            if func in get_default_nowrap_functions():
+            if func in _get_my_nowrap_functions():
                 return ret
             else:
                 pg = _scan_for_pg_from_args(args, kwargs)
@@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
         Args:
             dist_spec (_DistSpec): the target dist. spec.
         """
+        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
+            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
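Two things change in _convert_to_dist_spec: handle_trans_spec now receives self.data (a plain torch.Tensor thanks to the redirect above) rather than the ColoTensor itself, and a new guard refuses to convert a tensor that is the output of an autograd op. A hedged usage sketch, assuming set_dist_spec forwards to _convert_to_dist_spec as the test below exercises it:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

pg = ProcessGroup()                                # same assumption as in the first sketch
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg))
t.set_dist_spec(distspec.replicate())              # fine: t is a leaf, t.grad_fn is None

t.requires_grad_(True)
y = t * 2                                          # y now carries a grad_fn
# y.set_dist_spec(distspec.replicate())            # would trip the new assert instead of converting silently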
@@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
-        self.dist_spec = distspec.replicate()
+        self._convert_to_dist_spec(dist_spec=distspec.replicate())
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
@@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        # TODO(jiaruifang) check why this not work
-        # self.data = self.to_replicate()
-        self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
-                                                      self.process_group)
-        self.dist_spec = distspec.replicate()
-        return super().view(*args)
+        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        return replicated_t.view(*args)
 
     def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()


@@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
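The added assertion pins down the observable effect of the conversion: once the dist spec is set back to REPLICATE, the tensor must report the full global shape rather than a shard's shape, which exercises the .data redirect inside _convert_to_dist_spec. A worked instance of the expected value:

import torch
world_size = 2                                     # example launch size
expected = torch.Size((4 * world_size, 5))         # == torch.Size([8, 5]), the shape the new assert checks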