|
|
|
@ -7,6 +7,13 @@ from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.nn.layer.utils import divide |
|
|
|
|
from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern |
|
|
|
|
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward |
|
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorType(Enum): |
|
|
|
|
MODEL = 0 |
|
|
|
|
NONMODEL = 1 # mainly activations |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ColoTensor(object): |
|
|
|
|
""" Data Structure for Tensor in Colossal-AI |
|
|
|
@ -28,6 +35,7 @@ class ColoTensor(object):
|
|
|
|
|
device=None, |
|
|
|
|
torch_tensor=torch.empty(0), |
|
|
|
|
shard_spec: TensorSpec = TensorSpec(), |
|
|
|
|
is_model_data: bool = False, |
|
|
|
|
): |
|
|
|
|
self._size = size |
|
|
|
|
self._dtype = dtype |
|
|
|
@ -37,6 +45,10 @@ class ColoTensor(object):
|
|
|
|
|
self._torch_tensor = torch_tensor |
|
|
|
|
self._shard_spec = shard_spec |
|
|
|
|
self._shard_pattern = ShardPattern.NA |
|
|
|
|
if is_model_data: |
|
|
|
|
self._type = TensorType.MODEL |
|
|
|
|
else: |
|
|
|
|
self._type = TensorType.NONMODEL |
|
|
|
|
|
|
|
|
|
def __getitem__(self, key): |
|
|
|
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key]) |
|
|
|
@ -85,13 +97,14 @@ class ColoTensor(object):
|
|
|
|
|
return product(self._size) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor': |
|
|
|
|
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor': |
|
|
|
|
colo_t = ColoTensor(*tensor.size(), |
|
|
|
|
dtype=tensor.dtype, |
|
|
|
|
requires_grad=tensor.requires_grad, |
|
|
|
|
pin_memory=tensor.is_pinned(), |
|
|
|
|
device=tensor.device, |
|
|
|
|
torch_tensor=tensor if save_payload else torch.empty(0)) |
|
|
|
|
torch_tensor=tensor if save_payload else torch.empty(0), |
|
|
|
|
is_model_data=is_model_data) |
|
|
|
|
return colo_t |
|
|
|
|
|
|
|
|
|
def del_torch_tensor(self, save_shape=False) -> None: |
|
|
|
@ -120,31 +133,28 @@ class ColoTensor(object):
|
|
|
|
|
self._shard_spec = spec |
|
|
|
|
if shard == True: |
|
|
|
|
self.shard() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_shard_pattern(self, shard_pattern: ShardPattern): |
|
|
|
|
self._shard_pattern = shard_pattern |
|
|
|
|
|
|
|
|
|
def shard(self): |
|
|
|
|
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.' |
|
|
|
|
if self._shard_pattern is not ShardPattern.NA: # reshard |
|
|
|
|
if self._shard_pattern is not ShardPattern.NA: # reshard |
|
|
|
|
self.gather() |
|
|
|
|
# Model Parameters |
|
|
|
|
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns: |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern( |
|
|
|
|
ComputePattern.TP1DRow) |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) |
|
|
|
|
self._shard_1d(parallel_action=parallel_action, dim=-1) |
|
|
|
|
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). |
|
|
|
|
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). |
|
|
|
|
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns: |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern( |
|
|
|
|
ComputePattern.TP1DCol) |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) |
|
|
|
|
self._shard_1d(parallel_action=parallel_action, dim=0) |
|
|
|
|
self._shard_pattern = ShardPattern.Row |
|
|
|
|
|
|
|
|
|
def gather(self): |
|
|
|
|
assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.' |
|
|
|
|
assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.' |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern( |
|
|
|
|
ComputePattern.Activation) |
|
|
|
|
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP) |
|
|
|
|
if self._shard_pattern == ShardPattern.Row: |
|
|
|
|
dim = 0 |
|
|
|
|
elif self._shard_pattern == ShardPattern.Col: |
|
|
|
@ -159,9 +169,8 @@ class ColoTensor(object):
|
|
|
|
|
return self._shard_spec is not None and self._shard_spec.num_action > 0 |
|
|
|
|
|
|
|
|
|
def is_activation(self) -> bool: |
|
|
|
|
return self._shard_spec is not None and self._shard_spec.num_action == 1 \ |
|
|
|
|
and ComputePattern.Activation in self._shard_spec.compute_patterns |
|
|
|
|
|
|
|
|
|
return self._type == TensorType.NONMODEL |
|
|
|
|
|
|
|
|
|
def _shard_1d(self, parallel_action, dim=-1): |
|
|
|
|
num_partition = gpc.get_world_size(parallel_action.parallel_mode) |
|
|
|
|
local_rank = gpc.get_local_rank(parallel_action.parallel_mode) |
|
|
|
@ -169,8 +178,8 @@ class ColoTensor(object):
|
|
|
|
|
# Reshape to get shard for this rank and we don't want autograd |
|
|
|
|
# recording here for the narrow op and 'local_shard' should be a |
|
|
|
|
# leaf variable in the autograd graph. |
|
|
|
|
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach( |
|
|
|
|
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor? |
|
|
|
|
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous( |
|
|
|
|
) # TODO Shall we clone() here since detach() will point to the old tensor? |
|
|
|
|
self._torch_tensor.requires_grad = self._requires_grad |
|
|
|
|
self._size = self._torch_tensor.size() |
|
|
|
|
|
|
|
|
|