fix the merge

pull/6023/head
wangbluo 2024-08-19 10:11:27 +00:00
parent 2eb36839c6
commit 88b3f0698c
1 changed file with 7 additions and 7 deletions

View File

@@ -202,21 +202,21 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "split_gather":
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
if self.gather_output:
# All-gather across the partitions.