diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 9dead5b4b..3541e6a0a 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,4 +1,4 @@ -from .spec import ComputePattern, ParallelAction, TensorSpec +from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern from .op_wrapper import ( colo_op_impl,) from .colo_tensor import ColoTensor @@ -7,5 +7,5 @@ from ._ops import * __all__ = [ 'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction', - 'named_params_with_colotensor' + 'named_params_with_colotensor', 'ShardPattern' ] diff --git a/colossalai/tensor/_ops/layernorm.py b/colossalai/tensor/_ops/layernorm.py index 6658c05b1..b59d5a00b 100644 --- a/colossalai/tensor/_ops/layernorm.py +++ b/colossalai/tensor/_ops/layernorm.py @@ -27,6 +27,8 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None): eps = kwargs['eps'] if isinstance(input_tensor, ColoTensor): + if input_tensor.is_activation() and not input_tensor.is_gathered(): + input_tensor.gather() input_tensor = input_tensor.torch_tensor() if isinstance(weight, ColoTensor): weight = weight.torch_tensor() diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index 560f391e6..552492add 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -6,9 +6,75 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward from colossalai.nn.layer.utils import divide from colossalai.core import global_context as gpc from packaging import version -from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor +from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern +def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor: + parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + # Input:S[1] x Weight:S[0] = Output:P + # All-Reduce(Output) + bias = res + # Input:S[1] + if input_tensor.is_gathered(): + # Not splited yet. + assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ + 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size) + input_per_partition = split_forward_gather_backward(input_tensor.torch_tensor(), parallel_action.parallel_mode, dim=-1) + elif input_tensor.shard_pattern == ShardPattern.Col: + # Splited by 1Dcol + assert input_tensor.shape[-1] == weight.size(-1), \ + 'Invalid shapes in 1Drow forward: input={}, weight={}. 
Expected last dim of input {}.'.format( + input_tensor.shape, weight.size, weight.size(-1)) + input_per_partition = input_tensor.torch_tensor() + else: + raise NotImplementedError + + # Output:P + partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor()) + # Reduce(Output) + output = reduce_input(partial_output, parallel_action.parallel_mode) + # Bias + if bias is not None: + assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op' + output = output + bias.torch_tensor() + output = ColoTensor.init_from_torch_tensor(output) + return output + +def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor: + # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] + # All-Gather(Output) + # Input:B + parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) + if input_tensor.is_gathered(): + # Not splited yet. + assert input_tensor.shape[-1] == weight.size(-1), \ + 'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_tensor.shape, weight.size, weight.size(-1)) + input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode) + + # Bias:S[1] + if bias is not None: + assert bias.has_spec() and bias.shard_spec.num_action == 1 and \ + bias.shard_pattern in [ShardPattern.Col, ShardPattern.Row], \ + 'Invalid bias spec for 1Dcol Linear op' + + output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor()) + + output = ColoTensor.init_from_torch_tensor(output_parallel) + out_parallel_action_list = [ + ParallelAction( + priority=1, compute_pattern=ComputePattern.Activation, + parallel_mode=parallel_action.parallel_mode + ) + ] + output_spec = TensorSpec(out_parallel_action_list) + output.set_spec(output_spec, shard=False) + output.set_shard_pattern(ShardPattern.Col) + if parallel_action.gather_out: + # All-Gather(Output) + output.gather() + return output + @colo_op_impl(torch.nn.functional.linear) def colo_linear(types, args, kwargs, pg): """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. @@ -25,110 +91,29 @@ def colo_linear(types, args, kwargs, pg): else: bias = kwargs.get('bias', None) - bias_spec = None - if isinstance(bias, ColoTensor): - bias_spec = bias.shard_spec - bias = bias.torch_tensor() + if not isinstance(input_tensor, ColoTensor): + input_tensor = ColoTensor.init_from_torch_tensor(input_tensor) + + if not isinstance(weight, ColoTensor): + weight = ColoTensor.init_from_torch_tensor(weight) + + if bias is not None and not isinstance(bias, ColoTensor): + bias = ColoTensor.init_from_torch_tensor(bias) # Add communication logic before and after linear call. 
- if isinstance(weight, ColoTensor): - if weight.shard_spec == None or weight.shard_spec.num_action == 0: - assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for native Linear op' - if isinstance(input_tensor, ColoTensor): - input_tensor = input_tensor.torch_tensor() - if isinstance(weight, ColoTensor): - weight = weight.torch_tensor() - return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias)) - elif weight.shard_spec.num_action == 1: - parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) - compute_patterns = weight.shard_spec.compute_patterns - if ComputePattern.TP1DRow in compute_patterns: - # Input:S[1] x Weight:S[0] = Output:P - # All-Reduce(Output) + bias = res - # Input:S[1] - input_spec = None - if isinstance(input_tensor, ColoTensor): - input_spec = input_tensor.shard_spec - input_tensor = input_tensor.torch_tensor() - - if input_spec == None or input_spec.num_action == 0: - # Not splited yet. - assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ - 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size) - input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1) - elif input_tensor.shard_spec.num_action == 1: - if ComputePattern.TP1DCol in input_spec.compute_patterns: - # Splited by 1Dcol - assert input_tensor.shape[-1] == weight.size(-1), \ - 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_tensor.shape, weight.size, weight.size(-1)) - input_per_partition = input_tensor - else: - raise NotImplementedError - else: - raise NotImplementedError - - # Output:P - weight_ = weight.torch_tensor() - partial_output = torch.nn.functional.linear(input_per_partition, weight_) - # Reduce(Output) - output = reduce_input(partial_output, parallel_action.parallel_mode) - # Bias - if bias is not None: - assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for 1Drow Linear op' - output = output + bias - output = ColoTensor.init_from_torch_tensor(output) - return output - elif ComputePattern.TP1DCol in compute_patterns: - # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] - # All-Gather(Output) - # Input:B - input_spec = None - output_spec = None - parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) - if isinstance(input_tensor, ColoTensor): - input_spec = input_tensor.shard_spec - input_tensor = input_tensor.torch_tensor() - - if input_spec == None or input_spec.num_action == 0: - # Not splited yet. - assert input_tensor.shape[-1] == weight.size(-1), \ - 'Invalid shapes in 1Dcol forward: input={}, weight={}. 
Expected last dim of input {}.'.format( - input_tensor.shape, weight.size, weight.size(-1)) - input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode) - else: - raise NotImplementedError - # Bias:S[1] - if bias is not None: - assert bias_spec is not None and bias_spec.num_action == 1 and \ - ComputePattern.TP1DCol in bias_spec.compute_patterns, \ - 'Invalid bias spec for 1Dcol Linear op' - - weight_ = weight.torch_tensor() - output_parallel = torch.nn.functional.linear(input_parallel, weight_, bias) - - if parallel_action.gather_out: - # All-Gather(Output) - output = gather_forward_split_backward(output_parallel, parallel_action.parallel_mode, dim=-1) - output = ColoTensor.init_from_torch_tensor(output) - else: - output = ColoTensor.init_from_torch_tensor(output_parallel) - out_parallel_action_list = [ - ParallelAction( - priority=1, compute_pattern=ComputePattern.TP1DCol, - parallel_mode=parallel_action.parallel_mode - ) - ] - output_spec = TensorSpec(out_parallel_action_list) - # set ColoTensor spec - if output_spec is not None: - output.set_spec(output_spec) - return output - - else: - raise NotImplementedError + if not weight.has_spec(): # No Model Parallel Applied + assert not bias.has_spec(), 'Invalid bias spec for native Linear op' + input_tensor = input_tensor.torch_tensor() + weight = weight.torch_tensor() + bias = bias.torch_tensor() + return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias)) + elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied + compute_patterns = weight.shard_spec.compute_patterns + if ComputePattern.TP1DRow in compute_patterns: + return colo_linear_1Drow(input_tensor, weight, bias) + elif ComputePattern.TP1DCol in compute_patterns: + return colo_linear_1Dcol(input_tensor, weight, bias) else: raise NotImplementedError else: - return torch.nn.functional.linear(input_tensor, weight, bias) + raise NotImplementedError diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 3abd71621..f476b354e 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,4 +1,3 @@ -from colossalai.context import parallel_mode from .op_wrapper import _COLOSSAL_OPS import torch @@ -6,8 +5,8 @@ from typing import Tuple, Optional, Callable from numpy import product from colossalai.core import global_context as gpc from colossalai.nn.layer.utils import divide -from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction - +from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern +from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward class ColoTensor(object): """ Data Structure for Tensor in Colossal-AI @@ -37,6 +36,7 @@ class ColoTensor(object): self._device = device self._torch_tensor = torch_tensor self._shard_spec = shard_spec + self._shard_pattern = ShardPattern.NA def __getitem__(self, key): return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key]) @@ -45,6 +45,10 @@ class ColoTensor(object): def shard_spec(self) -> TensorSpec: return self._shard_spec + @property + def shard_pattern(self): + return self._shard_pattern + @property def data(self): return self._torch_tensor.data @@ -112,22 +116,51 @@ class ColoTensor(object): device=self._device) return self._torch_tensor - def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None: + def set_spec(self, spec: TensorSpec, shard: bool = True) -> None: self._shard_spec = spec - if lazy_shard == 
False: - self._shard() + if shard == True: + self.shard() + + def set_shard_pattern(self, shard_pattern: ShardPattern): + self._shard_pattern = shard_pattern - def _shard(self): + def shard(self): assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.' - if self._shard_spec.num_action == 1: - if ComputePattern.TP1DRow in self._shard_spec.compute_patterns: - parallel_action = self._shard_spec.get_action_by_compute_pattern( - ComputePattern.TP1DRow) - self._shard_1d(parallel_action=parallel_action, dim=-1) - elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns: - parallel_action = self._shard_spec.get_action_by_compute_pattern( - ComputePattern.TP1DCol) - self._shard_1d(parallel_action=parallel_action, dim=0) + if self._shard_pattern is not ShardPattern.NA: # reshard + self.gather() + # Model Parameters + if ComputePattern.TP1DRow in self._shard_spec.compute_patterns: + parallel_action = self._shard_spec.get_action_by_compute_pattern( + ComputePattern.TP1DRow) + self._shard_1d(parallel_action=parallel_action, dim=-1) + self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). + elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns: + parallel_action = self._shard_spec.get_action_by_compute_pattern( + ComputePattern.TP1DCol) + self._shard_1d(parallel_action=parallel_action, dim=0) + self._shard_pattern = ShardPattern.Row + + def gather(self): + assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.' + assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.' + parallel_action = self._shard_spec.get_action_by_compute_pattern( + ComputePattern.Activation) + if self._shard_pattern == ShardPattern.Row: + dim = 0 + elif self._shard_pattern == ShardPattern.Col: + dim = -1 + self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim) + self._shard_pattern = ShardPattern.NA + + def is_gathered(self) -> bool: + return self._shard_pattern == ShardPattern.NA + + def has_spec(self) -> bool: + return self._shard_spec is not None and self._shard_spec.num_action > 0 + + def is_activation(self) -> bool: + return self._shard_spec is not None and self._shard_spec.num_action == 1 \ + and ComputePattern.Activation in self._shard_spec.compute_patterns def _shard_1d(self, parallel_action, dim=-1): num_partition = gpc.get_world_size(parallel_action.parallel_mode) diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py index 14d4d2099..2584bc224 100644 --- a/colossalai/tensor/spec.py +++ b/colossalai/tensor/spec.py @@ -4,11 +4,16 @@ from colossalai.context.parallel_mode import ParallelMode class ComputePattern(Enum): + Activation = 0 # TODO(jzy) A tmp place to store Activation info. Find a better place in future. TP1DRow = 1 TP1DCol = 2 ZeRO = 3 DP = 4 +class ShardPattern(Enum): + NA = 0 + Row = 1 + Col = 2 class ParallelAction(object): @@ -18,6 +23,7 @@ class ParallelAction(object): self.parallel_mode = parallel_mode self.gather_out = gather_out + class TensorSpec(object): """ It contains two aspects of information: @@ -42,8 +48,9 @@ class TensorSpec(object): # We perform Linear Op according to compute pattern of TP1DRow. # After Linear Op, we split the tensors according to ZeRO. 
- def __init__(self, parallel_action_list: List[ParallelAction] = []): + def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA): self._parallel_action_list = parallel_action_list + self._shard_pattern = shard_pattern self.sort() @property @@ -57,6 +64,10 @@ class TensorSpec(object): @property def compute_patterns(self): return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list] + + @property + def shard_pattern(self): + return self._shard_pattern def sort(self): if len(self._parallel_action_list) > 0: diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index 1035e9c3b..49de72012 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -145,7 +145,7 @@ def run_linear_tp1d_row_test(): def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_linear_tp1d_row_test() + #run_linear_tp1d_row_test() run_linear_tp1d_col_test() @pytest.mark.dist diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 80029eabd..cc0bf73b5 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -26,6 +26,77 @@ def set_seed(seed): torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True +def run_1d_col_tp(): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + parallel_action_list_row = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec_row = TensorSpec(parallel_action_list_row) + + parallel_action_list_col = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec_col = TensorSpec(parallel_action_list_col) + + set_seed(1) + if rank == 0: + model_torch = model_builder(checkpoint=True) + model_torch = model_torch.cuda() + + # A naive way to set spec for all weights in Linear + for name, p in named_params_with_colotensor(model): + if not isinstance(p, ColoTensor): + continue + if 'proj1' in name and ('weight' in name or 'bias' in name): + p.set_spec(spec_col) + if 'proj2' in name and 'weight' in name: + p.set_spec(spec_row) + + model = model.cuda() + + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_torch = model_torch(data) + loss_torch = criterion(output_torch, label) + else: + output_torch = model_torch(data, label) + loss_torch = output_torch + + if rank == 0: + # print(loss.torch_tensor().item()) + # print('loss torch', loss_torch.item()) + assert 
torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2) + + loss.backward() + + if rank == 0: + loss_torch.backward() + if i > 5: + break def run_1d_row_tp(): # A simple net with two stacked nn.Linear
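
---

Note on the refactored linear op: `colo_linear_1Drow` implements `Input:S[1] x Weight:S[0] = Output:P` followed by an all-reduce, while `colo_linear_1Dcol` implements `Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]` followed by an optional all-gather. The single-process sketch below checks that decomposition numerically, with `torch.chunk`/`torch.cat` and a Python `sum` standing in for the collectives (`reduce_input`, `gather_forward_split_backward`); the world size and shapes are illustrative assumptions, not values used by the tests.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2                      # assumed tensor-parallel degree, for illustration only
batch, in_dim, out_dim = 4, 8, 6

x = torch.randn(batch, in_dim)
weight = torch.randn(out_dim, in_dim)   # torch.nn.Linear layout: (out_features, in_features)
bias = torch.randn(out_dim)
reference = F.linear(x, weight, bias)

# TP1DRow: weight is sharded along the input-feature dim (dim=-1), the input is split
# along the same dim, and the partial products are summed (the all-reduce done by
# reduce_input in colo_linear_1Drow). The bias stays replicated.
x_shards = x.chunk(world_size, dim=-1)
w_row_shards = weight.chunk(world_size, dim=-1)
row_out = sum(F.linear(xs, ws) for xs, ws in zip(x_shards, w_row_shards)) + bias
assert torch.allclose(row_out, reference, atol=1e-5)

# TP1DCol: weight and bias are sharded along the output-feature dim (dim=0), the input
# is replicated, and the per-rank outputs are concatenated along the last dim (the
# all-gather performed when parallel_action.gather_out is True).
w_col_shards = weight.chunk(world_size, dim=0)
b_shards = bias.chunk(world_size, dim=0)
col_out = torch.cat([F.linear(x, ws, bs) for ws, bs in zip(w_col_shards, b_shards)], dim=-1)
assert torch.allclose(col_out, reference, atol=1e-5)
```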
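The new `ShardPattern` enum records how a `ColoTensor` is currently laid out (`Row` = sharded along dim 0, `Col` = sharded along the last dim, `NA` = gathered), and `ColoTensor.gather()` uses it to pick the dimension along which the shards are reassembled. A minimal sketch that emulates that dispatch with `torch.cat` on one process (the real code gathers across the 1D tensor-parallel group via `gather_forward_split_backward`):

```python
from enum import Enum
import torch

class ShardPattern(Enum):   # mirrors colossalai/tensor/spec.py
    NA = 0
    Row = 1
    Col = 2

def gather_shards(shards, pattern: ShardPattern) -> torch.Tensor:
    """Emulate ColoTensor.gather(): Row shards are rejoined along dim 0,
    Col shards along the last dim."""
    assert pattern is not ShardPattern.NA, 'Only sharded tensors can be gathered.'
    dim = 0 if pattern is ShardPattern.Row else -1
    return torch.cat(shards, dim=dim)

full = torch.arange(12.).reshape(4, 3)
col_shards = list(full.chunk(3, dim=-1))        # the layout a 1Dcol linear leaves behind
assert torch.equal(gather_shards(col_shards, ShardPattern.Col), full)
```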
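The layernorm handler now gathers a sharded activation before calling the native op; the likely reason is that LayerNorm normalizes over the last dimension, which is exactly the dimension a 1Dcol linear leaves sharded, so per-shard statistics would be wrong. A quick single-process illustration (shapes and the shard count are assumptions for the example):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
full = F.layer_norm(x, normalized_shape=(8,))

# Normalizing each column shard independently uses per-shard mean/variance,
# so stitching the shard results back together does not reproduce the full result.
shard_wise = torch.cat(
    [F.layer_norm(s, normalized_shape=(4,)) for s in x.chunk(2, dim=-1)], dim=-1)
print((shard_wise - full).abs().max())   # non-zero: the activation must be gathered first
```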