From 666cfd094a767074ca921cf5cfd5fcbfd2ad327d Mon Sep 17 00:00:00 2001 From: Wesley Date: Thu, 31 Mar 2022 16:38:14 +0800 Subject: [PATCH] fix parallel_input flag for Linear1D_Col gather_output --- colossalai/nn/layer/parallel_1d/layers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 8e6a42394..c819cb9a8 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -302,7 +302,10 @@ class Linear1D_Col(ParallelLayer): with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() - set_parallel_input(True) + if self.gather_output: + set_parallel_input(False) + else: + set_parallel_input(True) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features