From 88b3f0698c976f11ca96940aa5895c24c4d1793e Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 19 Aug 2024 10:11:27 +0000 Subject: [PATCH] fix the merge --- colossalai/shardformer/layer/linear.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 085609c01..d77dd4965 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -202,21 +202,21 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication + ) elif self.seq_parallel_mode == "ring": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) else: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) if self.gather_output: # All-gather across the partitions.