From dc2cdaf3e8bc865ae5b8d653230876bba8dbf787 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 11 Oct 2024 13:44:40 +0800 Subject: [PATCH] [shardformer] optimize seq parallelism (#6086) * [shardformer] optimize seq parallelism * [shardformer] fix gpt2 fused linear col * [plugin] update gemini plugin * [plugin] update moe hybrid plugin * [test] update gpt2 fused linear test * [shardformer] fix gpt2 fused linear reduce --- colossalai/booster/plugin/gemini_plugin.py | 4 - .../booster/plugin/hybrid_parallel_plugin.py | 3 - .../plugin/moe_hybrid_parallel_plugin.py | 3 - colossalai/shardformer/layer/_operation.py | 243 ++++++------------ colossalai/shardformer/layer/linear.py | 34 +-- .../shardformer/layer/qkv_fused_linear.py | 66 ++--- colossalai/shardformer/policies/bert.py | 5 - colossalai/shardformer/policies/bloom.py | 3 - colossalai/shardformer/policies/chatglm2.py | 2 - colossalai/shardformer/policies/gpt2.py | 3 - colossalai/shardformer/policies/gptj.py | 4 - colossalai/shardformer/shard/shard_config.py | 15 -- .../test_gpt2_qkv_fused_linear_1d.py | 8 +- 13 files changed, 113 insertions(+), 280 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ae49aa8b1..4c8258113 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. @@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase): enable_flash_attention: bool = False, enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False, - enable_sequence_overlap: bool = False, enable_async_reduce: bool = True, use_fp8: bool = False, verbose: bool = False, @@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase): self.enable_flash_attention = enable_flash_attention self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False self.enable_jit_fused = enable_jit_fused - self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose self.tp_size = tp_size @@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase): enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=self.enable_sequence_parallelism, - enable_sequence_overlap=self.enable_sequence_overlap, ) def __del__(self): diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bb663f6a6..0674451a4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -951,7 +951,6 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. @@ -1002,7 +1001,6 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, sequence_parallelism_mode: str = None, - enable_sequence_overlap: bool = False, parallel_output: bool = True, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, @@ -1174,7 +1172,6 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, sequence_parallelism_mode=sequence_parallelism_mode, - enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 0807b3749..2f08d2183 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -140,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. @@ -189,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, sequence_parallelism_mode: str = None, - enable_sequence_overlap: bool = False, parallel_output: bool = True, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, @@ -351,7 +349,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, sequence_parallelism_mode=sequence_parallelism_mode, - enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 1d7a1f104..5499443b6 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - if ctx.async_grad_allreduce and fp8_communication: + if fp8_communication or not ctx.async_grad_allreduce: _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2") elif ctx.async_grad_allreduce: # Asynchronous all-reduce @@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group= for k in recv_tensors: send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k] + input_tensors = [] output_tensors = [] handles = communicate_step() # first round: special case, retrive from local tensor + input_tensors.append(input_to_gather) output_tensors.append(func(**input_to_gather, **input_local)) for i in range(group_size - 2): for handle in handles: @@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group= handles = communicate_step() # actual computation + input_tensors.append(send_tensors) output_tensors.append(func(**send_tensors, **input_local)) # final round: special case, no need to send/recv again for handle in handles: handle.wait() + input_tensors.append(send_tensors) output_tensors.append(func(**recv_tensors, **input_local)) - return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) + gathered_input = {} + for k in input_to_gather: + input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]] + gathered_input[k] = torch.cat(input_shards, dim=gather_dim) + + gathered_output = torch.cat( + output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim + ) + + return gathered_output, gathered_input class _GatherForwardReduceScatterBackward(torch.autograd.Function): @@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim - ctx.overlap = overlap if ring is True: input_to_gather = {"input": input_} input_local = {"weight": weight} - output = _ring_as_gather( + output, input_dict = _ring_as_gather( F.linear, input_to_gather=input_to_gather, input_local=input_local, process_group=process_group, ) + ctx.gathered_input = input_dict["input"] if bias is not None: output += bias else: input_parallel = _gather(input_, dim, process_group) + ctx.gathered_input = input_parallel if bias is not None: output = F.linear(input_parallel, weight, bias) else: @@ -329,100 +343,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group - overlap = ctx.overlap # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm if use_bias: bias = bias.view(bias.shape) - if not overlap: - input_parallel = _gather(input_, dim, process_group) - - total_input = input_parallel - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) - - if ctx.async_grad_reduce_scatter: - # Asynchronous reduce-scatter - input_list = [ - item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty( - input_.shape, dtype=input_parallel.dtype, device=input_parallel.device - ).contiguous() - handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - - if _grad_accum_fusion_available and weight.grad is not None: - grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) - grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) - grad_weight = None - else: - grad_weight = grad_output.t().matmul(total_input) - else: - grad_weight = grad_output.t().matmul(total_input) - - grad_bias = grad_output.sum(dim=0) if use_bias else None + input_parallel = ctx.gathered_input - if ctx.async_grad_reduce_scatter: - handle.wait() + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) - else: - input_ = input_.contiguous() - world_size = dist.get_world_size(process_group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - - # do all gather in is async way - gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) - # calculate gradient and prepare data asynchronously with all-gather - # calculate - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - grad_bias = grad_output.sum(dim=0) if use_bias else None - # prepare data + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() - # wait until all-gather finished - gather_handle.wait() - - # do reduce-scatter in async way - reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - input_parallel = torch.cat(tensor_list, dim=dim).contiguous() - # calculate gradient - if len(input_parallel.shape) > 2: - input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) - - if _grad_accum_fusion_available and weight.grad is not None: - grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad) - grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad) - grad_weight = None - else: - grad_weight = grad_output.t().matmul(input_parallel) + output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None else: - grad_weight = grad_output.t().matmul(input_parallel) - # grad_weight = grad_output.t().matmul(input_parallel) - # wait until reduce-scatter finished - reducescatter_handle.wait() + grad_weight = grad_output.t().matmul(total_input) + else: + grad_weight = grad_output.t().matmul(total_input) - return output, grad_weight, grad_bias, None, None, None, None, None + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None def _ring_as_reducescatter( @@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): # Convert the tensor shapes to 2D for execution compatibility if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) + total_input = total_input.reshape(-1, total_input.shape[-1]) grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward( - ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication - ): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim - ctx.overlap = overlap ctx.fp8_communication = fp8_communication if ring is True: - input_to_gather = {} - input_local = {} - input_to_gather["input"] = input_ - input_local["other"] = weight + input_to_gather = {"input": input_} + input_local = {"other": weight} - output = _ring_as_gather( + output, input_dict = _ring_as_gather( torch.matmul, input_to_gather=input_to_gather, input_local=input_local, process_group=process_group, gather_dim=dim, ) + ctx.gathered_input = input_dict["input"] else: input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") - + ctx.gathered_input = input_parallel output = torch.matmul(input_parallel, weight) if bias is not None: @@ -651,76 +611,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group - overlap = ctx.overlap - fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm weight = weight.view(weight.shape) if use_bias: bias = bias.view(bias.shape) - if not overlap: - input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2") - - total_input = input_parallel - grad_input = grad_output.matmul(weight.T) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) - - if ctx.async_grad_reduce_scatter: - # Asynchronous reduce-scatter - input_list = [ - item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty( - input_.shape, dtype=input_parallel.dtype, device=input_parallel.device - ).contiguous() - handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated - - grad_weight = total_input.t().matmul(grad_output) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if ctx.async_grad_reduce_scatter: - handle.wait() + input_parallel = ctx.gathered_input - else: - world_size = dist.get_world_size(process_group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - - # do all gather in is async way - gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) - # calculate gradient and prepare data asynchronously with all-gather - # calculate - grad_input = grad_output.matmul(weight.T) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - grad_bias = grad_output.sum(dim=0) if use_bias else None - # prepare data + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() - # wait until all-gather finished - gather_handle.wait() + output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None - # do reduce-scatter in async way - reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - input_parallel = torch.cat(tensor_list, dim=dim).contiguous() - # calculate gradient - if len(input_parallel.shape) > 2: - input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) - grad_weight = input_parallel.t().matmul(grad_output) - # wait until reduce-scatter finished - reducescatter_handle.wait() + if ctx.async_grad_reduce_scatter: + handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring ) @@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication ) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 52b0e79c6..0ba23c296 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import ( - gather_forward_reducescatter_backward, gather_forward_split_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, - reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import PaddingParallelModule, ParallelModule -from .utils import create_randomizer_with_offset +from .utils import create_randomizer_with_offset, is_share_sp_tp __all__ = ["Linear1D_Col", "Linear1D_Row"] @@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule): to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. - overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule): gather_output: bool = False, seq_parallel_mode: str = None, seq_parallel_dim: int = 1, - overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule): self.gather_output = gather_output self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim - self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode == "split_gather": - input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication - ) - output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "ring": + if is_share_sp_tp(self.seq_parallel_mode): output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + input_parallel, + self.weight, + bias, + self.process_group, + True, + self.seq_parallel_dim, + ring=self.seq_parallel_mode == "ring", ) else: output_parallel = linear_with_async_comm( @@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode == "split_gather": - output_parallel = F.linear(input_, self.weight) - output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "ring": + if is_share_sp_tp(self.seq_parallel_mode): output = linear_reducescatter_forward_gather_backward( input_, self.weight, process_group=self.process_group, dim=self.seq_parallel_dim, - ring=True, + ring=self.seq_parallel_mode == "ring", ) else: output_parallel = F.linear(input_, self.weight) @@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. - overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index a1e25ff3a..6e469686b 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,19 +25,17 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import ( - gather_forward_reducescatter_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, - reduce_backward, reduce_forward, reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule -from .utils import create_randomizer_with_offset +from .utils import create_randomizer_with_offset, is_share_sp_tp __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] @@ -222,10 +220,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - async_communication: bool = False, gather_output: bool = False, seq_parallel_mode: str = None, - overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -240,12 +236,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.out_features = out_features self.gather_output = gather_output self.seq_parallel_mode = seq_parallel_mode - self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.split_sizes = split_sizes self.process_group = process_group - self.async_communication = async_communication self.fp8_communication = fp8_communication assert ( @@ -370,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode == "split_gather": + if is_share_sp_tp(self.seq_parallel_mode): input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, @@ -379,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.process_group, True, 1, - self.overlap, - fp8_communication=self.fp8_communication, - ) - elif self.seq_parallel_mode == "ring": - input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, - self.weight, - bias, - self.process_group, - True, - 1, - self.overlap, - True, + ring=self.seq_parallel_mode == "ring", fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) + input_parallel = input_ output_parallel = matmul_with_async_comm( input_parallel, self.weight, bias, self.process_group, - self.async_communication, + True, fp8_communication=self.fp8_communication, ) else: @@ -620,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) - elif self.seq_parallel_mode == "split_gather": + elif is_share_sp_tp(self.seq_parallel_mode): output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward( output_parallel, @@ -628,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): 1, self.fp8_communication, ) - elif self.seq_parallel_mode == "ring": - output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward( - output_parallel, - self.process_group, - 1, - ) else: raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") @@ -691,7 +665,6 @@ class FusedLinear1D_Col(ParallelModule): gather_output: bool = False, seq_parallel_mode: str = None, seq_parallel_dim: int = 1, - overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -706,7 +679,6 @@ class FusedLinear1D_Col(ParallelModule): self.gather_output = gather_output self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim - self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.split_sizes = split_sizes @@ -830,16 +802,15 @@ class FusedLinear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode == "split_gather": - input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication - ) - output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "ring": + if is_share_sp_tp(self.seq_parallel_mode): output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + input_parallel, + self.weight, + bias, + self.process_group, + True, + self.seq_parallel_dim, + ring=self.seq_parallel_mode == "ring", ) else: output_parallel = linear_with_async_comm( @@ -1031,18 +1002,13 @@ class FusedLinear1D_Row(ParallelModule): ) input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group) - if self.seq_parallel_mode == "split_gather": - output_parallel = F.linear(input_, self.weight) - output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "ring": + if is_share_sp_tp(self.seq_parallel_mode): output = linear_reducescatter_forward_gather_backward( input_, self.weight, process_group=self.process_group, dim=self.seq_parallel_dim, - ring=True, + ring=self.seq_parallel_mode == "ring", ) else: output_parallel = F.linear(input_, self.weight) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 4c33e14bc..09673d396 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -73,7 +73,6 @@ class BertPolicy(Policy): ) sp_mode = "split_gather" - overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: @@ -97,7 +96,6 @@ class BertPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -106,7 +104,6 @@ class BertPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -115,7 +112,6 @@ class BertPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -140,7 +136,6 @@ class BertPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, }, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index a43ac02d0..7c6259e85 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -57,7 +57,6 @@ class BloomPolicy(Policy): ) sp_mode = "split_gather" - overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: @@ -78,7 +77,6 @@ class BloomPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -99,7 +97,6 @@ class BloomPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 1b7d2db85..c003570a0 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy): f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather" ) sp_mode = "split_gather" - overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather"] if sp_mode == "all_to_all": @@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index faacf91b2..08accaaea 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -65,7 +65,6 @@ class GPT2Policy(Policy): f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" ) self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" - overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention if self.shard_config.enable_tensor_parallelism: @@ -94,7 +93,6 @@ class GPT2Policy(Policy): kwargs={ "split_sizes": [self.model.config.hidden_size] * 3, "seq_parallel_mode": sp_mode, - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -109,7 +107,6 @@ class GPT2Policy(Policy): kwargs={ "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size], "seq_parallel_mode": sp_mode, - "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, }, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 6f0c8803c..9fcca1385 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -51,7 +51,6 @@ class GPTJPolicy(Policy): self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") - overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -76,7 +75,6 @@ class GPTJPolicy(Policy): suffix="attn.k_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -84,7 +82,6 @@ class GPTJPolicy(Policy): suffix="attn.q_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), @@ -92,7 +89,6 @@ class GPTJPolicy(Policy): suffix="attn.v_proj", target_module=col_nn.Linear1D_Col, kwargs={ - "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication, }, ), diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 1219119bb..911226e5c 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -26,7 +26,6 @@ class ShardConfig: enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. - enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. @@ -44,7 +43,6 @@ class ShardConfig: enable_jit_fused: bool = False enable_sequence_parallelism: bool = False sequence_parallelism_mode: str = None - enable_sequence_overlap: bool = False parallel_output: bool = True make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None @@ -84,24 +82,12 @@ class ShardConfig: assert ( self.enable_tensor_parallelism ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" - elif self.sequence_parallelism_mode in ["all_to_all"]: - # assert ( - # not self.enable_tensor_parallelism - # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" - if self.enable_sequence_overlap: - self.enable_sequence_overlap = False - warnings.warn( - f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}" - ) else: if self.sequence_parallelism_mode: self.sequence_parallelism_mode = None warnings.warn( f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" ) - assert ( - not self.enable_sequence_overlap - ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True" # get the tensor parallel size if not self.enable_tensor_parallelism: @@ -134,4 +120,3 @@ class ShardConfig: # This can cause non-in-place param sharding when used without ZeRO. # It may also slow down training when seq len is small. Plz enable manually. # self.enable_sequence_parallelism = True - # self.enable_sequence_overlap = True diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 923075e0e..a45beb771 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -41,7 +41,7 @@ class Conv1D(nn.Module): return x -def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b gather_output=True, seq_parallel_mode=seq_parallel_mode, split_sizes=[64] * 3, - overlap=overlap, ) assert linear.weight.shape == torch.Size([48, 192]) @@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): @parameterize("lazy_init", [False, True]) @parameterize("seq_parallel_mode", ["split_gather", None]) -@parameterize("overlap", [True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel_mode) check_linear_conv_1d_row(lazy_init, seq_parallel_mode)