From e241b74f24ac4efe4712bcefedfd7f14f3dd7b37 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:30:50 +0800 Subject: [PATCH] [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code --- colossalai/shardformer/layer/_operation.py | 87 ++++++++++++----- .../shardformer/layer/qkv_fused_linear.py | 4 +- .../shardformer/policies/base_policy.py | 19 ---- colossalai/shardformer/policies/gpt2.py | 94 ++++++++++--------- .../test_gpt2_qkv_fused_linear_1d.py | 10 +- 5 files changed, 120 insertions(+), 94 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 55d9413b9..f45ccc64b 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim + ctx.overlap = overlap input_parallel = _gather(input_, dim, process_group) @@ -312,37 +313,70 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + overlap = ctx.overlap - # TODO: overlap SP input with gradient computation - input_parallel = _gather(input_, dim, process_group) + if not overlap: + input_parallel = _gather(input_, dim, process_group) - total_input = input_parallel - grad_input = grad_output.matmul(weight.T) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) - # TODO: overlap SP input with gradient computation - if ctx.async_grad_reduce_scatter: - # Asynchronous reduce-scatter + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = 
dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() - handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # reduce-scatter scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() - grad_weight = total_input.t().matmul(grad_output) - grad_bias = grad_output.sum(dim=0) if use_bias else None + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() - if ctx.async_grad_reduce_scatter: - handle.wait() - - return output, grad_weight, grad_bias, None, None, None + return output, grad_weight, grad_bias, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim) + async_grad_reduce_scatter, dim, overlap) def gather_forward_split_backward(input_, dim, process_group): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index ccb2bf7ea..5ce77805f 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -177,6 +177,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): async_communication: bool = False, gather_output: bool = False, seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -190,6 +191,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.out_features = out_features self.gather_output = gather_output self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -308,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.seq_parallel: input_parallel = input_ output_parallel = 
matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1) + self.process_group, True, 1, self.overlap) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 7022a1cfd..961c6a525 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -226,22 +226,3 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] - - def append_seq_parallel_to_policy( - self, - suffix_list: List[str], - module_policy_description: ModulePolicyDescription, - ): - r""" - Append the sequence parallel policy to the policy for the given key. - - Args: - suffix_list (List[str]): the suffix list of the module to be parallelized - policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated - """ - - for sub_description in module_policy_description.sub_module_replacement: - if (sub_description.suffix in suffix_list): - if sub_description.kwargs is None: - sub_description.kwargs = {} - sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index acae26309..5093fd469 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -37,7 +37,8 @@ class GPT2Policy(Policy): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -50,47 +51,54 @@ class GPT2Policy(Policy): ), ]) - policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GPT2Block] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // 
self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -126,8 +134,6 @@ class GPT2Policy(Policy): if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} - suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] - self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) return policy diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index ae6a1dc90..4c0f884a7 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): process_group=None, gather_output=True, seq_parallel=seq_parallel, - n_fused=3) + n_fused=3, + overlap=overlap) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize('lazy_init', [False, True]) @parameterize('seq_parallel', [False, True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel) +@parameterize('overlap', [True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) check_linear_conv_1d_row(lazy_init, seq_parallel)
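
How the overlap=True branch in _MatmulWithGatherForwardReduceScatterBackward.backward hides communication: the all-gather of the sequence-parallel input is launched asynchronously and overlapped with the gradients that only need grad_output and the weight (grad_input and grad_bias); once the gathered input is available, the reduce-scatter of grad_input is launched asynchronously and overlapped with the weight-gradient matmul. The sketch below restates that pattern as a standalone function. It is a minimal illustration, not the ColossalAI API: the function name and packaging are invented here, a torch.distributed process group is assumed to be initialized, and dim is the sequence dimension that was gathered in the forward pass.

    import torch
    import torch.distributed as dist

    def overlapped_matmul_backward(grad_output, input_, weight, process_group, dim, use_bias=True):
        """Illustrative sketch: gather the input asynchronously, compute grad_input/grad_bias
        in the meantime, then reduce-scatter grad_input asynchronously while computing grad_weight."""
        world_size = dist.get_world_size(process_group)

        # 1) Start gathering the per-rank input shards along the sequence dimension.
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

        # 2) While the all-gather is in flight, compute what does not need the full input.
        grad_input = grad_output.matmul(weight.T)
        grad_output = grad_output.contiguous()
        if grad_output.dim() > 2:                      # flatten to 2D for the matmuls below
            grad_output = grad_output.view(-1, grad_output.shape[-1])
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        # 3) Prepare the reduce-scatter buffers for the input gradient.
        input_list = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
        output = torch.empty_like(input_)              # per-rank shard of the input gradient

        # 4) Wait for the gathered input, then overlap the reduce-scatter with grad_weight.
        gather_handle.wait()
        rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

        input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
        if input_parallel.dim() > 2:
            input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
        grad_weight = input_parallel.t().matmul(grad_output)

        rs_handle.wait()
        return output, grad_weight, grad_bias

Note that the gather handle is waited on before the reduce-scatter is issued, so only one collective is in flight on the group at a time; the saving comes from hiding the all-gather behind the grad_input/grad_bias math and the reduce-scatter behind the grad_weight matmul.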
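
At the layer level the feature is exposed as the new overlap argument of GPT2FusedLinearConv1D_Col, which is only consulted on the seq_parallel path (it is forwarded to matmul_gather_forward_reducescatter_backward). A construction sketch mirroring the updated test; the import paths are assumptions, and it needs an initialized default process group plus a CUDA device:

    from transformers.pytorch_utils import Conv1D                        # import path assumed
    from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col   # import path assumed

    linear = Conv1D(192, 48).cuda()
    linear_col = GPT2FusedLinearConv1D_Col.from_native_module(
        linear,
        process_group=None,   # None -> default process group
        gather_output=True,
        seq_parallel=True,    # overlap only matters when sequence parallelism is on
        overlap=True,         # enable the overlapped all-gather / reduce-scatter backward
        n_fused=3,            # fused q, k, v projections
    )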
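
At the model level the switch is shard_config.enable_sequence_overlap: the GPT2 policy now reads it (together with enable_sequence_parallelism) and passes the seq_parallel/overlap kwargs directly in its sub-module replacements, in place of the removed append_seq_parallel_to_policy helper. A rough configuration sketch, assuming ShardConfig exposes constructor arguments matching the attribute names the policy reads; check the actual dataclass before relying on it:

    from colossalai.shardformer import ShardConfig   # import path assumed

    shard_config = ShardConfig(
        enable_tensor_parallelism=True,     # the fused linears are tensor-parallel layers
        enable_sequence_parallelism=True,   # turns on the sequence-parallel GPT2 forward
        enable_sequence_overlap=True,       # turns on the overlapped backward from this patch
    )
    # shard_config is then handed to the shardformer machinery as usual.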