From c554b7f559b592c4d358db677c87658b11a6341c Mon Sep 17 00:00:00 2001
From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Date: Mon, 28 Aug 2023 17:16:40 +0800
Subject: [PATCH] [shardformer/fix overlap bug] fix overlap bug, add overlap
 as an option in shardco… (#4516)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom
---
 colossalai/shardformer/layer/_operation.py   | 53 ++++++++-----------
 colossalai/shardformer/layer/linear.py       |  2 +-
 colossalai/shardformer/policies/bert.py      | 21 ++++--
 colossalai/shardformer/policies/bloom.py     | 11 +++-
 colossalai/shardformer/policies/chatglm2.py  |  4 +-
 colossalai/shardformer/shard/shard_config.py |  9 ++++
 .../test_layer/test_linear_1d.py             |  2 +-
 7 files changed, 63 insertions(+), 39 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index f1f48273c..55d9413b9 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -211,43 +211,36 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             handle.wait()
 
         else:
-            # create new stream for calculate the gradient
-            calculate_stream = torch.cuda.Stream()
-
-            # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-
-            # calculate gradient in calculate_stream
-            with torch.cuda.stream(calculate_stream):
-                # calculate
-                grad_input = grad_output.matmul(weight)
-                grad_output = grad_output.contiguous()
-                # Convert the tensor shapes to 2D for execution compatibility
-                if len(grad_output.shape) > 2:
-                    grad_output = grad_output.view(-1, grad_output.shape[-1])
-                grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-                # prepare data
-                input_list = [
-                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-                ]
-                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            # do all gather in async way
+            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
             gather_handle.wait()
+            # do reduce-scatter in async way
             reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            with torch.cuda.stream(calculate_stream):
-                input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-                if len(input_parallel.shape) > 2:
-                    input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-                print(grad_output.shape, input_parallel.shape)
-                grad_weight = grad_output.t().matmul(input_parallel)
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = grad_output.t().matmul(input_parallel)
+            # wait until reduce-scatter finished
             reducescatter_handle.wait()
 
         return output, grad_weight, grad_bias, None, None, None, None
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 81c3f973f..111d51b3f 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -75,7 +75,7 @@ class Linear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  seq_parallel: bool = False,
                  seq_parallel_dim: int = 1,
-                 overlap: bool = False,
+                 overlap: torch.cuda.Stream = None,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
                  bias_: Optional[Parameter] = None,
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 19dd95fd6..a141b7bd8 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -56,6 +56,7 @@ class BertPolicy(Policy):
         policy = {}
 
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                 "attention.self.all_head_size":
@@ -71,17 +72,26 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.self.query",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.key",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
@@ -99,7 +109,10 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 21db13f6e..7c418d02b 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -45,6 +45,7 @@ class BloomPolicy(Policy):
         policy = {}
 
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
                 "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -55,7 +56,10 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attention.query_key_value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={'seq_parallel': use_sequence_parallel}),
+                    kwargs={
+                        'seq_parallel': use_sequence_parallel,
+                        'overlap': overlap
+                    }),
                 SubModuleReplacementDescription(
                     suffix="self_attention.dense",
                     target_module=col_nn.Linear1D_Row,
@@ -67,7 +71,10 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_h_to_4h",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={'seq_parallel': use_sequence_parallel}),
+                    kwargs={
+                        'seq_parallel': use_sequence_parallel,
+                        'overlap': overlap
+                    }),
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_4h_to_h",
                     target_module=col_nn.Linear1D_Row,
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index b0d684a67..5bcbc2acc 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -50,6 +50,7 @@ class ChatGLMPolicy(Policy):
         policy = {}
 
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
                                                            sub_module_replacement=[
@@ -81,7 +82,8 @@ class ChatGLMPolicy(Policy):
                                                 target_module=col_nn.Linear1D_Col,
                                                 kwargs={
                                                     'seq_parallel': use_sequence_parallel,
-                                                    'seq_parallel_dim': 0
+                                                    'seq_parallel_dim': 0,
+                                                    'overlap': overlap
                                                 }),
                 SubModuleReplacementDescription(suffix="self_attention.dense",
                                                 target_module=col_nn.Linear1D_Row,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 900f8475c..c5c3d185e 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -20,6 +20,8 @@ class ShardConfig:
         enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -29,6 +31,7 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_sequence_parallelism: bool = False
+    enable_sequence_overlap: bool = False
     # pipeline_parallel_size: int
     # data_parallel_size: int
 
@@ -41,6 +44,11 @@ class ShardConfig:
         return self._tensor_parallel_size
 
     def __post_init__(self):
+        if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
+            raise ValueError(
+                "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
+        if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
+            raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
         if not self.enable_tensor_parallelism:
             self._tensor_parallel_size = 1
         else:
@@ -59,3 +67,4 @@ class ShardConfig:
             self.enable_flash_attention = True
             self.enable_jit_fused = True
             self.enable_sequence_parallelism = True
+            self.enable_sequence_overlap = True
diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index 3ad8f14b9..e6d86d533 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -168,7 +168,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
 
 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-@parameterize('overlap', [False, True])
+@parameterize('overlap', [True])
 def run_dist_linear_test(lazy_init, seq_parallel, overlap):
     check_linear_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_1d_row(lazy_init, seq_parallel)