diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 157da5db6..143eeae58 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,7 +1,9 @@ +from .spec import ComputePattern, ParallelAction, TensorSpec from .op_wrapper import ( colo_op_impl,) from .colo_tensor import ColoTensor from .utils import convert_parameter from ._ops import * -__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl'] +__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', + 'TensorSpec', 'ParallelAction'] diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py index 034a2f695..7438d6ef7 100644 --- a/colossalai/tensor/_ops/__init__.py +++ b/colossalai/tensor/_ops/__init__.py @@ -2,4 +2,4 @@ from .init import colo_uniform from .linear import colo_linear from .element_wise import colo_mean from .layernorm import colo_layernorm -from .loss import colo_cross_entropy \ No newline at end of file +from .loss import colo_cross_entropy diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index c6bb78dd4..519678480 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward from colossalai.nn.layer.utils import divide from colossalai.core import global_context as gpc from packaging import version -from colossalai.utils.cuda import get_current_device - +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction @colo_op_impl(torch.nn.functional.linear) def colo_linear(types, args, kwargs, pg): @@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg): # Add communication logic before and after linear call. if isinstance(weight, ColoTensor): - if weight.shard_spec == None: + if weight.shard_spec == None or weight.shard_spec.num_action == 0: if isinstance(input_tensor, ColoTensor): input_tensor = input_tensor.torch_tensor() if isinstance(weight, ColoTensor): weight = weight.torch_tensor() return torch.nn.functional.linear(input_tensor, weight, bias) - elif weight.shard_spec == '1Drow': - # Input:S[1] x Weight:S[0] = Output:P - # All-Reduce(Output) + bias = res - assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ - 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size) - # Input:S[1] - input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1) - # Output:P - device = get_current_device() # TODO where to put to(deivce)? - weight_ = weight.torch_tensor().to(device) - partial_output = torch.nn.functional.linear(input_per_partition, weight_) - # Reduce(Output) - output = reduce_input(partial_output, ParallelMode.PARALLEL_1D) - # Bias - if bias is not None: - bias_ = bias.to(device) - output = output + bias_ - return output - + elif weight.shard_spec.num_action == 1: + if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns: + # Input:S[1] x Weight:S[0] = Output:P + # All-Reduce(Output) + bias = res + assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ + 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size) + # Input:S[1] + if isinstance(input_tensor, ColoTensor): + input_tensor = input_tensor.torch_tensor() + parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1) + # Output:P + weight_ = weight.torch_tensor() + partial_output = torch.nn.functional.linear(input_per_partition, weight_) + # Reduce(Output) + output = reduce_input(partial_output, ParallelMode.PARALLEL_1D) + # Bias + if bias is not None: + bias_ = bias + output = output + bias_ + return ColoTensor.init_from_torch_tensor(output) + else: + raise NotImplementedError else: raise NotImplementedError else: diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 8d67d6f69..1824f0b49 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,13 +1,12 @@ +from colossalai.context import parallel_mode from .op_wrapper import _COLOSSAL_OPS import torch from typing import Tuple, Optional from numpy import product from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode from colossalai.nn.layer.utils import divide -from colossalai.utils.cuda import get_current_device - +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction class ColoTensor(object): """ Data Structure for Tensor in Colossal-AI @@ -28,7 +27,7 @@ class ColoTensor(object): pin_memory=False, device=None, torch_tensor=torch.empty(0), - shard_spec: str = None, + shard_spec: TensorSpec = TensorSpec(), ): self._size = size self._dtype = dtype @@ -39,7 +38,7 @@ class ColoTensor(object): self._shard_spec = shard_spec @property - def shard_spec(self) -> Optional[str]: + def shard_spec(self) -> TensorSpec: return self._shard_spec @property @@ -109,27 +108,27 @@ class ColoTensor(object): device=self._device) return self._torch_tensor - def set_spec(self, spec: str, lazy_shard: bool = False) -> None: + def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None: self._shard_spec = spec if lazy_shard == False: self._shard() def _shard(self): assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.' - if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now. - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - local_rank = gpc.get_local_rank(ParallelMode.TENSOR) - dim = -1 - chunk_size = divide(self._size[dim], num_partition) - device = get_current_device() - # Reshape to get shard for this rank and we don't want autograd - # recording here for the narrow op and 'local_shard' should be a - # leaf variable in the autograd graph. - self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach( - ).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor? - self._torch_tensor.requires_grad = self._requires_grad - self._size = self._torch_tensor.size() - self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu + if self._shard_spec.num_action == 1: + if ComputePattern.TP1DRow in self._shard_spec.compute_patterns: + parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + num_partition = gpc.get_world_size(parallel_action.parallel_mode) + local_rank = gpc.get_local_rank(parallel_action.parallel_mode) + dim = -1 + chunk_size = divide(self._size[dim], num_partition) + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach( + ).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor? + self._torch_tensor.requires_grad = self._requires_grad + self._size = self._torch_tensor.size() @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -151,5 +150,5 @@ class ColoTensor(object): kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()} return func(*args, **kwargs) - def backward(self, retain_graph: bool = False): - self._torch_tensor.backward(retain_graph=retain_graph) + def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False): + self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph) diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py index 8339c50c6..ccd85d9cb 100644 --- a/colossalai/tensor/spec.py +++ b/colossalai/tensor/spec.py @@ -1,8 +1,6 @@ from enum import Enum from typing import Tuple, List from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc - class ComputePattern(Enum): TP1DRow = 1 @@ -12,17 +10,13 @@ class ComputePattern(Enum): class ParallelAction(object): - priority = 0 - compute_pattern = ComputePattern.DP - process_group = gpc.get_group(ParallelMode.DATA) - - def __init__(self, priority, compute_pattern, process_group) -> None: + def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None: self.priority = priority self.compute_pattern = compute_pattern - self.process_group = process_group + self.parallel_mode = parallel_mode -class TensorSpec(Enum): +class TensorSpec(object): """ It contains two aspects of information: First, How are tensors distributed in Heterougenous memory space. @@ -44,4 +38,28 @@ class TensorSpec(Enum): # Before Linear Op, we gather the tensors according to ZeRO. # We perform Linear Op according to compute pattern of TP1DRow. # After Linear Op, we split the tensors according to ZeRO. - parallel_action_list: List[ParallelAction] = [] + def __init__(self, parallel_action_list: List[ParallelAction] = []): + self._parallel_action_list = parallel_action_list + self.sort() + + @property + def parallel_action_list(self): + return self._parallel_action_list + + @property + def num_action(self): + return len(self._parallel_action_list) + + @property + def compute_patterns(self): + return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list] + + def sort(self): + if len(self._parallel_action_list) > 0: + self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority) + + def get_action_by_compute_pattern(self, compute_pattern: ComputePattern): + for parallel_action in self._parallel_action_list: + if parallel_action.compute_pattern == compute_pattern: + return parallel_action + return None diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index 760818efc..335b17cf5 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.core import global_context as gpc +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk @@ -45,7 +46,11 @@ def run_linear_tp1d_row_test(): # replace the torch nn.Parameters with ColoTensor sharded_weight = ColoTensor.init_from_torch_tensor(W) - sharded_weight.set_spec(spec="1Drow") # reshard + parallel_action_list = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec = TensorSpec(parallel_action_list) + sharded_weight.set_spec(spec=spec) # reshard sharded_bias = ColoTensor.init_from_torch_tensor(B) replace_parameter_add_grad(layer, sharded_weight, sharded_bias) out = layer(A)