mirror of https://github.com/hpcaitech/ColossalAI
[tensor] redirect .data.__get__ to a tensor instance (#1239)
parent 20da6e48c8
commit f6add9b720
@@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
 from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
+from functools import lru_cache
 
 from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor import distspec, ProcessGroup
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional
+from typing import Optional, Set, Callable
 
 
+@lru_cache(None)
+def _get_my_nowrap_functions() -> Set[Callable]:
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+    }
+
+
 def _convert_output(output, pg: ProcessGroup):
@@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):
 
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            if func in get_default_nowrap_functions():
+            if func in _get_my_nowrap_functions():
                 return ret
             else:
                 pg = _scan_for_pg_from_args(args, kwargs)
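Taken together, the import change and this dispatch change swap torch.overrides.get_default_nowrap_functions() for a custom, cached set that also contains Tensor.data.__get__, so the result of a .data access is handed back as-is instead of being re-wrapped into ColoTensor. Below is a minimal, self-contained sketch of the same technique in plain PyTorch; WrappedTensor and _nowrap_functions are illustrative names, not ColossalAI API.

import torch
from functools import lru_cache
from typing import Callable, Set


@lru_cache(None)
def _nowrap_functions() -> Set[Callable]:
    # accessors whose results should be returned untouched (as plain torch.Tensor)
    return {
        torch.Tensor._base.__get__,
        torch.Tensor.grad.__get__,
        torch.Tensor._grad.__get__,
        torch.Tensor.data.__get__,    # the accessor this commit adds to the set
    }


class WrappedTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if func in _nowrap_functions():
            return ret    # skip re-wrapping for accessors in the nowrap set
        return ret.as_subclass(cls) if isinstance(ret, torch.Tensor) else ret


t = torch.randn(4, 5).as_subclass(WrappedTensor)
print(type(t + 1).__name__)    # WrappedTensor: ordinary ops are re-wrapped
print(type(t.data).__name__)   # Tensor: .data.__get__ is in the nowrap set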
@@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
         Args:
             dist_spec (_DistSpec): the target dist. spec.
         """
+        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
+            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
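This hunk also hardens the in-place conversion: the payload is now handed to DistSpecManager as self.data rather than self, and a new guard refuses tensors that were produced by an autograd op, since reassigning .data only makes sense on tensors without a grad_fn. A small plain-PyTorch sketch of that guard pattern; swap_payload_ is a hypothetical helper, not ColossalAI API.

import torch


def swap_payload_(t: torch.Tensor, new_payload: torch.Tensor) -> None:
    # Hypothetical helper mirroring the guard added to _convert_to_dist_spec:
    # an in-place .data swap is only safe on tensors without a grad_fn.
    assert t.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
    with torch.no_grad():
        t.data = new_payload    # payload (and shape) change in place, identity preserved


x = torch.randn(4, 4, requires_grad=True)    # leaf tensor: grad_fn is None
swap_payload_(x, torch.zeros(2, 8))
print(x.shape)                               # torch.Size([2, 8]), same Python object

y = x * 2                                    # non-leaf: produced by an autograd op
try:
    swap_payload_(y, torch.zeros(4))
except AssertionError as err:
    print(err)                               # refused: y has a grad_fn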
@@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
-        self.dist_spec = distspec.replicate()
+        self._convert_to_dist_spec(dist_spec=distspec.replicate())
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
@@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        # TODO(jiaruifang) check why this not work
-        # self.data = self.to_replicate()
-        self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
-                                                      self.process_group)
-        self.dist_spec = distspec.replicate()
-        return super().view(*args)
+        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        return replicated_t.view(*args)
 
     def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()
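This hunk stops view() from silently rewriting the tensor's payload and dist spec: a sharded tensor is first converted to a replicated copy via convert_to_dist_spec, and the view is taken on that copy, leaving self untouched. A toy contrast of the two strategies in plain PyTorch; the gathered tensor stands in for the replicated copy, and both helpers are illustrative, not ColossalAI API.

import torch


def view_by_mutation(t: torch.Tensor, gathered: torch.Tensor, *shape):
    # old behavior: overwrite t's payload with the gathered data, then view t;
    # t is permanently changed as a side effect
    t.data = gathered
    return t.view(*shape)


def view_by_copy(t: torch.Tensor, gathered: torch.Tensor, *shape):
    # new behavior: view the gathered copy and leave t as it was
    return gathered.view(*shape)


shard = torch.arange(4.)      # pretend: the local shard held by this rank
full = torch.arange(8.)       # pretend: the tensor after gathering all shards

v = view_by_copy(shard, full, 2, 4)
print(v.shape, shard.shape)   # torch.Size([2, 4]) torch.Size([4]) -- shard untouched

w = view_by_mutation(shard, full, 2, 4)
print(w.shape, shard.shape)   # torch.Size([2, 4]) torch.Size([8]) -- shard was rewritten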
@@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
+
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
 
 