Browse Source

fix parallel_input flag for Linear1D_Col gather_output

pull/582/head
Wesley 3 years ago committed by アマデウス
parent
commit
666cfd094a
  1. 5
      colossalai/nn/layer/parallel_1d/layers.py

5
colossalai/nn/layer/parallel_1d/layers.py

@ -302,7 +302,10 @@ class Linear1D_Col(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(True)
if self.gather_output:
set_parallel_input(False)
else:
set_parallel_input(True)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features

Loading…
Cancel
Save