From 70c58cfd4f81a157693b34694ad443da89a87cc8 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 23 Jun 2023 16:07:09 +0800 Subject: [PATCH] [shardformer] supported fused qkv checkpoint (#4073) --- colossalai/shardformer/layer/_operation.py | 86 +++++++++- colossalai/shardformer/layer/embedding.py | 4 +- colossalai/shardformer/layer/linear.py | 16 +- colossalai/shardformer/layer/linear_conv.py | 162 ++++++++++++------ .../shardformer/layer/parallel_module.py | 8 +- colossalai/tensor/d_tensor/__init__.py | 8 +- colossalai/tensor/d_tensor/api.py | 127 ++++++++++++++ .../test_layer/test_linear_1d.py | 64 ++++++- .../test_layer/test_linearconv_1d.py | 20 ++- .../test_model/test_shard_gpt2.py | 13 +- 10 files changed, 420 insertions(+), 88 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 280d55263..7e97bee01 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +import torch.nn.functional as F try: import fused_mix_prec_layer_norm_cuda @@ -46,6 +47,53 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None +class MatmulWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + class LinearWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. 
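The backward pass of MatmulWithAsyncCommunication above overlaps the all-reduce of the input gradient with the weight-gradient GEMM. Below is a minimal single-function sketch of that overlap pattern, assuming torch.distributed is already initialized with an NCCL backend; the function name and tensor names are illustrative and not part of this patch.

import torch
import torch.distributed as dist

def overlapped_backward_sketch(grad_output, saved_input, weight, process_group):
    # dL/dX = dL/dY @ W^T -- this is the tensor that must be all-reduced across
    # tensor-parallel ranks in the column-parallel case.
    grad_input = grad_output.matmul(weight.T)

    # Flatten any leading (batch, seq) dims so the weight-gradient GEMM is a plain 2D matmul.
    grad_output_2d = grad_output.contiguous().view(-1, grad_output.shape[-1])
    input_2d = saved_input.view(-1, saved_input.shape[-1])

    # Launch the all-reduce asynchronously ...
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)

    # ... and compute dL/dW = X^T @ dL/dY while the communication is in flight.
    grad_weight = input_2d.t().matmul(grad_output_2d)
    grad_bias = grad_output_2d.sum(dim=0)

    handle.wait()    # grad_input is only valid once the all-reduce has completed
    return grad_input, grad_weight, grad_bias

The only difference from LinearWithAsyncCommunication is that the forward multiplies by weight directly instead of weight.t(), matching the (in_features, out_features) storage layout of huggingface's Conv1D.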
@@ -58,9 +106,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce - output = torch.matmul(input_, weight.t()) if bias is not None: - output = output + bias + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) return output @staticmethod @@ -114,7 +163,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function): return _gather(grad_output, ctx.dim, ctx.process_group), None, None -class _ReduceInput(torch.autograd.Function): +class _ReduceForward(torch.autograd.Function): """ All-reduce the input from the model parallel region. @@ -132,6 +181,25 @@ class _ReduceInput(torch.autograd.Function): return grad_output, None +class _ReduceBackward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.process_group = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.process_group), None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -198,6 +266,10 @@ class _GatherForwardSplitBackward(torch.autograd.Function): return _split(grad_output, ctx.dim, ctx.process_group), None, None +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) @@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group): return _SplitForwardGatherBackward.apply(input_, dim, process_group) -def reduce_input(input_, process_group): - return _ReduceInput.apply(input_, process_group) +def reduce_forward(input_, process_group): + return _ReduceForward.apply(input_, process_group) + + +def reduce_backward(input_, process_group): + return _ReduceBackward.apply(input_, process_group) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 23601a04a..db39a457b 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -15,7 +15,7 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param -from ._operation import gather_forward_split_backward, reduce_input +from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset @@ -276,5 +276,5 @@ class VocabParallelEmbedding1D(ParallelModule): # Mask the output embedding. output_parallel[input_mask, :] = 0. # Reduce across all the model parallel GPUs. 
- output = reduce_input(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group) return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 912be26b9..d952d5eec 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param -from colossalai.utils.cuda import get_current_device from ._operation import ( gather_forward_split_backward, linear_with_async_comm, - reduce_input, + reduce_forward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule): assert input_.shape[-1] == self.weight.shape[-1], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) input_parallel = input_ + # Matrix multiply. bias = self.bias if not self.skip_bias_add else None output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') # Parameters. # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) sharded_weight = shard_colwise(weight, self.process_group) self.weight = sharded_tensor_to_param(sharded_weight) @@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule): output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: @@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule): return output else: return output, self.bias - return output, self.bias - return output, self.bias diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py index 2d1dacf2c..e856abc14 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -14,13 +14,18 @@ from torch.nn.parameter import Parameter from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device +from colossalai.tensor.d_tensor.api import ( + customized_distributed_tensor_to_param, + distribute_tensor_with_customization, + shard_rowwise, + sharded_tensor_to_param, +) from ._operation import ( gather_forward_split_backward, - linear_with_async_comm, - reduce_input, + matmul_with_async_comm, + reduce_backward, + reduce_forward, split_forward_gather_backward, ) from 
.parallel_module import ParallelModule @@ -29,11 +34,69 @@ from .utils import create_randomizer_with_offset __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] +def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): + """ + The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + """ + # get the number of slice for the fused qkv + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_fused) + + # split the fused qkv + # from + # [Q, K, V] + # to + # [Q1, Q2, K1, K2, V1, V2] + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + + # rearrange the slice into the final order + # from + # [Q1, Q2, K1, K2, V1, V2] + # to + # [Q1, K1, V1], [Q2, K2, V2] + weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + return weight_of_current_rank + + +def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): + """ + The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + """ + world_size = dist.get_world_size(group=process_group) + + # gather the tensors + # from + # [Q1, K1, V1], [Q2, K2, V2] + # to + # [Q1, K1, V1, Q2, K2, V2] + origin_device = qkv.device + qkv = qkv.cuda() + gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] + dist.all_gather(gather_list, qkv, group=process_group) + gather_weight = torch.cat(gather_list, dim=-1) + gather_weight = gather_weight.to(origin_device) + qkv = qkv.to(origin_device) + + # rearrange the tensor slices + # from + # [Q1, K1, V1, Q2, K2, V2] + # to + # [Q1, Q2, K1, K2, V1, V2] + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + reordered_chunk_list = [] + for i in range(n_fused): + reordered_chunk_list.extend(weight_chunks[i::n_fused]) + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + return reordered_gather_weight + + class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: in_features (int): size of each input sample. @@ -41,6 +104,7 @@ class LinearConv1D_Col(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. 
gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output @@ -63,8 +127,10 @@ class LinearConv1D_Col(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + async_communication: bool = False, gather_output: bool = False, skip_bias_add: bool = False, + n_fused: int = 3, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -75,23 +141,34 @@ class LinearConv1D_Col(ParallelModule): self.gather_output = gather_output self.skip_bias_add = skip_bias_add self.device = device + self.n_fused = n_fused self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) + self.async_communication = async_communication if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, self.num_partitions) - # Parameters. # Initialize weight. - if device is None: - device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + + def shard_fn(tensor): + return split_fused_qkv(tensor, self.n_fused, self.process_group) + + def gather_fn(tensor): + return gather_fused_qkv(tensor, 3, self.process_group) + + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) + self.weight = customized_distributed_tensor_to_param(sharded_weight) if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + bias = torch.empty(self.out_features, **factory_kwargs) + + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) + self.bias = customized_distributed_tensor_to_param(sharded_bias) else: self.bias = None @@ -103,7 +180,7 @@ class LinearConv1D_Col(ParallelModule): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. 
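To make the column reordering performed by split_fused_qkv above concrete, here is a small single-process sketch of the same index arithmetic. Passing rank and world_size as plain integers instead of reading them from a process group is an illustrative simplification, and the helper name is not part of this patch.

import torch

def split_fused_qkv_sketch(qkv: torch.Tensor, n_fused: int, rank: int, world_size: int) -> torch.Tensor:
    # Conv1D stores its weight as (in_features, out_features), so the fused Q/K/V
    # projections sit side by side along the last dimension: [Q1, Q2, K1, K2, V1, V2].
    chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
    order = torch.arange(world_size * n_fused)
    # Rank r keeps chunks r, r + world_size, r + 2 * world_size, ... i.e. [Qr, Kr, Vr].
    return torch.cat([chunks[i] for i in order[rank::world_size]], dim=-1)

# Six labelled columns stand in for [Q1, Q2, K1, K2, V1, V2] with world_size=2, n_fused=3.
w = torch.arange(6).repeat(4, 1)                                        # shape (4, 6)
print(split_fused_qkv_sketch(w, n_fused=3, rank=0, world_size=2)[0])    # tensor([0, 2, 4]) -> Q1, K1, V1
print(split_fused_qkv_sketch(w, n_fused=3, rank=1, world_size=2)[0])    # tensor([1, 3, 5]) -> Q2, K2, V2

gather_fused_qkv reverses this: each rank contributes its [Qr, Kr, Vr] block via all_gather and the chunks are re-interleaved back into [Q1, Q2, K1, K2, V1, V2], which is why a checkpoint saved from the sharded module matches the original fused layout.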
@@ -135,29 +212,12 @@ class LinearConv1D_Col(ParallelModule): # TODO: copy the sharded weights with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_fused) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) - sharded_weight = shard_colwise(rearanged_weight, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group) + linear_1d.weight.data.copy_(sharded_weight.data) if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - sharded_bias = shard_colwise(rearanged_bias, process_group) - linear_1d.bias.copy_(sharded_bias.contiguous()) + sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group) + linear_1d.bias.data.copy_(sharded_bias.data) return linear_1d @@ -169,15 +229,18 @@ class LinearConv1D_Col(ParallelModule): bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ + assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ + input_parallel = reduce_backward(input_, self.process_group) + # input_parallel = input_ + # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -192,7 +255,8 @@ class LinearConv1D_Col(ParallelModule): class LinearConv1D_Row(ParallelModule): - r""" Linear layer with row parallelism + r""" Linear layer with row parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: in_features (int): size of each input sample. @@ -243,11 +307,10 @@ class LinearConv1D_Row(ParallelModule): # Parameters. # Initialize weight. 
- if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if self.stream_chunk_num > 1: # TODO() work for inference only @@ -295,7 +358,7 @@ class LinearConv1D_Row(ParallelModule): # the weigh to the linear layer is a transpose # thus shard on col is equal to shard on row sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + linear_1d.weight.data.copy_(sharded_weight.data) if bias: linear_1d.bias.copy_(module.bias.data) @@ -325,12 +388,12 @@ class LinearConv1D_Row(ParallelModule): def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ + assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) @@ -342,7 +405,7 @@ class LinearConv1D_Row(ParallelModule): output_parallel_list = [None for i in range(self.stream_chunk_num)] handle_list = [] for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) handle = torch.distributed.all_reduce(output_parallel_list[i], group=self.process_group, async_op=True) @@ -352,9 +415,8 @@ class LinearConv1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) + output_parallel = torch.matmul(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 5edcb9dde..bda147b12 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module from colossalai.tensor.d_tensor import ( distribute_tensor, + distribute_tensor_with_customization, get_device_mesh, get_sharding_spec, + is_customized_distributed_tensor, is_distributed_tensor, sharded_tensor_to_param, to_global, + to_global_for_customized_distributed_tensor, ) __all__ = ['ParallelModule'] @@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC): for name, param in self._parameters.items(): if param is not None: param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): destination[prefix + 
name] = to_global(param_) + elif is_customized_distributed_tensor(param_): + destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) else: destination[prefix + name] = param_ @@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC): sharding_spec = get_sharding_spec(param) sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) # This is used to avoid copying uninitialized parameters into # non-lazy modules, since they dont have the hook to do the checks diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 52eae0e14..3ae38a125 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -1,10 +1,13 @@ from .api import ( compute_global_numel, + customized_distributed_tensor_to_param, distribute_tensor, + distribute_tensor_with_customization, get_device_mesh, get_global_shape, get_layout, get_sharding_spec, + is_customized_distributed_tensor, is_distributed_tensor, is_sharded, redistribute, @@ -12,6 +15,7 @@ from .api import ( shard_rowwise, sharded_tensor_to_param, to_global, + to_global_for_customized_distributed_tensor, ) from .layout import Layout from .sharding_spec import ShardingSpec @@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec __all__ = [ 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', - 'redistribute', 'get_layout' - 'Layout', 'ShardingSpec' + 'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization', + 'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec' ] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index a38e5e6b7..95a44e09e 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: """ assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' return dtensor.dist_layout.sharding_spec + + +# ====================================================== +# Some sharding does not obey the SPMD style +# e.g. Fused QKV layer in GPT2 +# we support customize sharding with the following APIs +# ====================================================== +def is_customized_distributed_tensor(tensor: torch.Tensor): + """ + Check whether the given tensor is a customized distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a customized distributed tensor. + """ + return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn') + + +def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. 
+ """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), 'The shard_fn must be callable.' + assert callable(gather_fn), 'The gather_fn must be callable.' + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + + sharded_tensor = shard_fn(tensor) + + # set the shard_fn and gather_fn as attributes of the distributed tensor + sharded_tensor.shard_fn = shard_fn + sharded_tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor) + + return sharded_tensor + + +def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Gather the given tensor to the global tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Tensor: The global tensor. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + return dtensor.gather_fn(dtensor) + + +def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + """ + Convert the given customized distributed tensor to a parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' 
+ + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) + return param diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index a2b8bf22c..da3bdc1d7 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -27,8 +27,13 @@ def check_linear_1d_col(): # check computation correctness x = torch.rand(4, 32).cuda() - out = linear(x) - gather_out = linear_col(x) + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + out = linear(x_for_unshard) + gather_out = linear_col(x_for_shard) assert_close(out, gather_out) # check backward correctness @@ -39,6 +44,11 @@ def check_linear_1d_col(): target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] assert_close(target_grad, linear_col.weight.grad) + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + def check_linear_1d_row(): linear = nn.Linear(32, 128).cuda() @@ -49,8 +59,14 @@ def check_linear_1d_row(): # check computation correctness x = torch.rand(4, 32).cuda() - out = linear(x) - gather_out = linear_row(x) + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_row(x_for_shard) assert_close(out, gather_out) # check backward correctness @@ -61,11 +77,49 @@ def check_linear_1d_row(): target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] assert_close(target_grad, linear_row.weight.grad) + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_col_plus_row(): + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + + # check computation correctness + x = torch.rand(4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + unshard_out = linear_2(linear_1(x_for_unshard)) + shard_out = linear_row(linear_col(x_for_shard)) + assert_close(unshard_out, shard_out) + + # check backward correctness + unshard_out.sum().backward() + shard_out.sum().backward() + + rank = dist.get_rank() + target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank] + assert_close(target_1_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_1d_col() - # check_linear_1d_row() + check_linear_1d_row() + check_linear_col_plus_row() @rerun_if_address_is_in_use() diff --git 
a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py index e0c97178d..efdb88351 100644 --- a/tests/test_shardformer/test_layer/test_linearconv_1d.py +++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py @@ -5,6 +5,7 @@ from torch.testing import assert_close import colossalai from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row +from colossalai.shardformer.layer.linear_conv import split_fused_qkv from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -53,9 +54,15 @@ def check_linear_conv_1d_col(): linear = Conv1D(192, 48).cuda() linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3) - assert linear_conv_col.weight.shape == torch.Size([96, 48]) + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + # check computation correctness x = torch.rand(4, 48).cuda() out = linear(x) @@ -66,16 +73,16 @@ def check_linear_conv_1d_col(): out.sum().backward() gather_out.sum().backward() - rank = dist.get_rank() - target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] - assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad) + target_grad = split_fused_qkv(linear.weight.grad, 3, None) + assert_close(target_grad, linear_conv_col.weight.grad) def check_linear_1d_row(): linear = Conv1D(192, 48).cuda() linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) - assert linear_row.weight.shape == torch.Size([192, 24]) + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) # check computation correctness @@ -89,13 +96,14 @@ def check_linear_1d_row(): gather_out.sum().backward() rank = dist.get_rank() - target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] assert_close(target_grad, linear_row.weight.grad) def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_conv_1d_col() + check_linear_1d_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 9aa02ec34..676267c2c 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -20,20 +20,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad equality if org_model.__class__.__name__ == 'GPT2Model': - org_grad = org_model.h[0].attn.c_attn.weight.grad - shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous() + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad else: org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous() + shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for 
_ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=1) assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" def check_gpt2(rank, world_size, port):
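Finally, the state-dict round trip exercised in test_linearconv_1d.py above works because _save_to_state_dict gathers a customized distributed tensor with its gather_fn while _load_from_state_dict re-shards the loaded tensor with its shard_fn. The following is a hedged sketch of that round trip using the d_tensor API added in this patch, assuming torch.distributed has been initialized on CUDA devices (for example via colossalai.launch as in the tests); the simple chunk/all_gather pair below stands in for split_fused_qkv/gather_fused_qkv.

import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (
    customized_distributed_tensor_to_param,
    distribute_tensor_with_customization,
    to_global_for_customized_distributed_tensor,
)

def shard_fn(tensor):
    # Keep this rank's column block; split_fused_qkv plays this role for fused QKV weights.
    return tensor.chunk(dist.get_world_size(), dim=-1)[dist.get_rank()]

def gather_fn(tensor):
    # Reassemble the full tensor; gather_fused_qkv plays this role for fused QKV weights.
    shards = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, tensor.contiguous())
    return torch.cat(shards, dim=-1)

full_weight = torch.rand(48, 192).cuda()                                # same shape as Conv1D(192, 48).weight
local_shard = distribute_tensor_with_customization(full_weight, shard_fn, gather_fn)
param = customized_distributed_tensor_to_param(local_shard)             # shard_fn/gather_fn survive detach()/clone()
restored = to_global_for_customized_distributed_tensor(param.detach())  # what _save_to_state_dict writes out
assert restored.shape == full_weight.shape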