diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index 8ca80b4ca..560f391e6 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,12 +1,12 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.context import ParallelMode
-from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
+    gather_forward_split_backward, reduce_grad
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.tensor import ComputePattern
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 
 
 @colo_op_impl(torch.nn.functional.linear)
@@ -25,39 +25,107 @@ def colo_linear(types, args, kwargs, pg):
     else:
         bias = kwargs.get('bias', None)
 
+    bias_spec = None
     if isinstance(bias, ColoTensor):
-        assert bias.shard_spec.num_action == 0, f"We currently only support bias is duplicated among processes in the linear operator"
+        bias_spec = bias.shard_spec
         bias = bias.torch_tensor()
-
+
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
         if weight.shard_spec == None or weight.shard_spec.num_action == 0:
+            assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for native Linear op'
             if isinstance(input_tensor, ColoTensor):
                 input_tensor = input_tensor.torch_tensor()
             if isinstance(weight, ColoTensor):
                 weight = weight.torch_tensor()
             return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
         elif weight.shard_spec.num_action == 1:
-            if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
+            parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+            compute_patterns = weight.shard_spec.compute_patterns
+            if ComputePattern.TP1DRow in compute_patterns:
                 # Input:S[1] x Weight:S[0] = Output:P
                 # All-Reduce(Output) + bias = res
-                assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
-                    'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                    input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
                 # Input:S[1]
+                input_spec = None
                 if isinstance(input_tensor, ColoTensor):
+                    input_spec = input_tensor.shard_spec
                     input_tensor = input_tensor.torch_tensor()
-                parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-                input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
+
+                if input_spec == None or input_spec.num_action == 0:
+                    # Not splited yet.
+                    assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
+                        'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                        input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
+                    input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
+                elif input_tensor.shard_spec.num_action == 1:
+                    if ComputePattern.TP1DCol in input_spec.compute_patterns:
+                        # Splited by 1Dcol
+                        assert input_tensor.shape[-1] == weight.size(-1), \
+                            'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                            input_tensor.shape, weight.size, weight.size(-1))
+                        input_per_partition = input_tensor
+                    else:
+                        raise NotImplementedError
+                else:
+                    raise NotImplementedError
+
                 # Output:P
                 weight_ = weight.torch_tensor()
                 partial_output = torch.nn.functional.linear(input_per_partition, weight_)
                 # Reduce(Output)
-                output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+                output = reduce_input(partial_output, parallel_action.parallel_mode)
                 # Bias
                 if bias is not None:
+                    assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for 1Drow Linear op'
                     output = output + bias
-                return ColoTensor.init_from_torch_tensor(output)
+                output = ColoTensor.init_from_torch_tensor(output)
+                return output
+            elif ComputePattern.TP1DCol in compute_patterns:
+                # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
+                # All-Gather(Output)
+                # Input:B
+                input_spec = None
+                output_spec = None
+                parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+                if isinstance(input_tensor, ColoTensor):
+                    input_spec = input_tensor.shard_spec
+                    input_tensor = input_tensor.torch_tensor()
+
+                if input_spec == None or input_spec.num_action == 0:
+                    # Not splited yet.
+                    assert input_tensor.shape[-1] == weight.size(-1), \
+                        'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                        input_tensor.shape, weight.size, weight.size(-1))
+                    input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
+                else:
+                    raise NotImplementedError
+                # Bias:S[1]
+                if bias is not None:
+                    assert bias_spec is not None and bias_spec.num_action == 1 and \
+                        ComputePattern.TP1DCol in bias_spec.compute_patterns, \
+                        'Invalid bias spec for 1Dcol Linear op'
+
+                weight_ = weight.torch_tensor()
+                output_parallel = torch.nn.functional.linear(input_parallel, weight_, bias)
+
+                if parallel_action.gather_out:
+                    # All-Gather(Output)
+                    output = gather_forward_split_backward(output_parallel, parallel_action.parallel_mode, dim=-1)
+                    output = ColoTensor.init_from_torch_tensor(output)
+                else:
+                    output = ColoTensor.init_from_torch_tensor(output_parallel)
+                    out_parallel_action_list = [
+                        ParallelAction(
+                            priority=1, compute_pattern=ComputePattern.TP1DCol,
+                            parallel_mode=parallel_action.parallel_mode
+                        )
+                    ]
+                    output_spec = TensorSpec(out_parallel_action_list)
+                # set ColoTensor spec
+                if output_spec is not None:
+                    output.set_spec(output_spec)
+                return output
+
             else:
                 raise NotImplementedError
         else:
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index e22fa5850..89b1835f4 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -121,18 +121,25 @@ class ColoTensor(object):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
         if self._shard_spec.num_action == 1:
             if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-                parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-                num_partition = gpc.get_world_size(parallel_action.parallel_mode)
-                local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
-                dim = -1
-                chunk_size = divide(self._size[dim], num_partition)
-                # Reshape to get shard for this rank and we don't want autograd
-                # recording here for the narrow op and 'local_shard' should be a
-                # leaf variable in the autograd graph.
-                self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
-                ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
-                self._torch_tensor.requires_grad = self._requires_grad
-                self._size = self._torch_tensor.size()
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                    ComputePattern.TP1DRow)
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+            elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                    ComputePattern.TP1DCol)
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+
+    def _shard_1d(self, parallel_action, dim=-1):
+        num_partition = gpc.get_world_size(parallel_action.parallel_mode)
+        local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        chunk_size = divide(self._size[dim], num_partition)
+        # Reshape to get shard for this rank and we don't want autograd
+        # recording here for the narrow op and 'local_shard' should be a
+        # leaf variable in the autograd graph.
+        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
+        ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
+        self._torch_tensor.requires_grad = self._requires_grad
+        self._size = self._torch_tensor.size()
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index 18099cc0c..14d4d2099 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -12,11 +12,11 @@ class ComputePattern(Enum):
 
 
 class ParallelAction(object):
-    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
+    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA, gather_out=True) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
         self.parallel_mode = parallel_mode
-
+        self.gather_out = gather_out
 
 class TensorSpec(object):
     """
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
index 335b17cf5..1035e9c3b 100644
--- a/tests/test_tensor/test_linear_tp.py
+++ b/tests/test_tensor/test_linear_tp.py
@@ -16,6 +16,69 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
 
 
+def run_linear_tp1d_col_test():
+    device = get_current_device()
+    dtype = torch.float32
+    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    in_features = 4
+    out_features = 8
+
+    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+    layer_master = torch.nn.Linear(in_features, out_features)
+    layer = torch.nn.Linear(in_features, out_features)
+
+    A_shape = (2, in_features)
+    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A = broadcast_tensor_chunk(A_master, chunk_size=1)
+    A.requires_grad = True
+
+    W_shape = (out_features, in_features)
+    W_master = torch.randn(W_shape, dtype=dtype, device=device)
+    W = broadcast_tensor_chunk(W_master, chunk_size=1)
+    W.requires_grad = True
+
+    B_shape = (out_features)
+    B_master = torch.randn(B_shape, dtype=dtype, device=device)
+    B = broadcast_tensor_chunk(B_master, chunk_size=1)
+    B.requires_grad = True
+
+    # replace the torch nn.Parameters with ColoTensor
+    sharded_weight = ColoTensor.init_from_torch_tensor(W)
+    sharded_bias = ColoTensor.init_from_torch_tensor(B)
+    parallel_action_list = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec = TensorSpec(parallel_action_list)
+    sharded_weight.set_spec(spec)    # reshard
+    sharded_bias.set_spec(spec)
+
+    replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
+    out = layer(A)
+
+    replace_parameter_add_grad(layer_master, W_master, B_master)
+    A_master.requires_grad = True
+    #C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
+    C_master = layer_master(A_master)
+    C = C_master.clone()
+
+    check_equal(out, C)
+
+    grad_shape = C_master.shape
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
+    out.backward(grad)
+
+    grad_master = grad_master.clone()
+    C_master.backward(grad_master)
+
+    W_grad = W_master.grad
+    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank]
+    check_equal(W_grad, layer.weight.grad)
+
+    B_grad = B_master.grad
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[local_rank]
+    check_equal(B_grad, layer.bias.grad)
 
 def run_linear_tp1d_row_test():
     device = get_current_device()
@@ -83,7 +146,7 @@ def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_linear_tp1d_row_test()
-
+    run_linear_tp1d_col_test()
 
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
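
Taken together, the patch lets a caller shard a Linear weight and bias along the output dimension (TP1DCol) and have colo_linear insert the matching communication. The snippet below is a minimal usage sketch distilled from run_linear_tp1d_col_test above, not part of the patch; it assumes colossalai.launch() has already created the 1D tensor-parallel group (as run_dist() does) and that ColoTensor's __torch_function__ dispatch routes torch.nn.functional.linear to the colo_op_impl handler.

# Usage sketch (assumption: a 1D tensor-parallel group is already initialized via
# colossalai.launch(config=dict(parallel=dict(tensor=dict(mode="1d", size=N))))).
import torch
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, ComputePattern, ParallelAction, TensorSpec

in_features, out_features = 4, 8
weight = ColoTensor.init_from_torch_tensor(torch.randn(out_features, in_features))
bias = ColoTensor.init_from_torch_tensor(torch.randn(out_features))

# TP1DCol: weight and bias are split along dim 0 by ColoTensor._shard_1d(dim=0);
# gather_out=True (the new ParallelAction default) all-gathers the output.
spec = TensorSpec([
    ParallelAction(priority=1,
                   compute_pattern=ComputePattern.TP1DCol,
                   parallel_mode=ParallelMode.PARALLEL_1D)
])
weight.set_spec(spec)    # shards the weight in place
bias.set_spec(spec)      # shards the bias in place

x = torch.randn(2, in_features)
# Dispatch through ColoTensor.__torch_function__ is expected to send this call
# to colo_linear, which runs the column-parallel matmul and gathers the result,
# so y holds the full (2, out_features) output on every rank.
y = torch.nn.functional.linear(x, weight, bias)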