mirror of https://github.com/hpcaitech/ColossalAI
[tensor] redirect .data.__get__ to a tensor instance (#1239)
parent
20da6e48c8
commit
f6add9b720
|
@ -2,13 +2,24 @@ from .op_wrapper import _COLOSSAL_OPS
|
|||
from .const import TensorType
|
||||
from copy import copy
|
||||
import torch
|
||||
from torch.overrides import get_default_nowrap_functions
|
||||
from functools import lru_cache
|
||||
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
from colossalai.tensor import distspec, ProcessGroup
|
||||
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
||||
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
|
||||
from typing import Optional
|
||||
from typing import Optional, Set, Callable
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def _get_my_nowrap_functions() -> Set[Callable]:
|
||||
Tensor = torch.Tensor
|
||||
return {
|
||||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
}
|
||||
|
||||
|
||||
def _convert_output(output, pg: ProcessGroup):
|
||||
|
@ -154,7 +165,7 @@ class ColoTensor(torch.Tensor):
|
|||
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
if func in get_default_nowrap_functions():
|
||||
if func in _get_my_nowrap_functions():
|
||||
return ret
|
||||
else:
|
||||
pg = _scan_for_pg_from_args(args, kwargs)
|
||||
|
@ -170,8 +181,9 @@ class ColoTensor(torch.Tensor):
|
|||
Args:
|
||||
dist_spec (_DistSpec): the target dist. spec.
|
||||
"""
|
||||
assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
|
||||
with DistSpecManager.no_grad():
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
|
||||
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
|
||||
self.dist_spec = dist_spec
|
||||
|
||||
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
|
@ -182,8 +194,7 @@ class ColoTensor(torch.Tensor):
|
|||
"""to_replicate_
|
||||
an inline member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
|
||||
self.dist_spec = distspec.replicate()
|
||||
self._convert_to_dist_spec(dist_spec=distspec.replicate())
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
|
@ -223,12 +234,8 @@ class ColoTensor(torch.Tensor):
|
|||
"""
|
||||
if self.is_replicate():
|
||||
return super().view(*args)
|
||||
# TODO(jiaruifang) check why this not work
|
||||
# self.data = self.to_replicate()
|
||||
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
|
||||
self.process_group)
|
||||
self.dist_spec = distspec.replicate()
|
||||
return super().view(*args)
|
||||
replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
|
||||
return replicated_t.view(*args)
|
||||
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
"""override the torch buildin size()
|
||||
|
|
|
@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
|
|||
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_dist_spec(distspec.replicate())
|
||||
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue