mirror of https://github.com/hpcaitech/ColossalAI
fix the merge
parent
2eb36839c6
commit
88b3f0698c
|
@ -202,21 +202,21 @@ class Linear1D_Col(ParallelModule):
|
|||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = gather_forward_reducescatter_backward(
|
||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
|
|
Loading…
Reference in New Issue