From 10dd8226b16a1303e20b25cba552103e196526f0 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 8 Sep 2022 16:40:56 +0800
Subject: [PATCH] add gather_output for VocabParallelClassifier1D (#1569)

---
 colossalai/nn/layer/parallel_1d/layers.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 7b89c5e1f..fd26f67e8 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -283,11 +283,13 @@ class VocabParallelClassifier1D(ParallelLayer):
                  weight: Parameter = None,
                  bias: bool = True,
                  dtype: torch.dtype = None,
+                 gather_output: bool = False,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
         self.in_features = in_features
         self.num_classes = num_classes
+        self.gather_output = gather_output
         self.parallel_input = get_parallel_input()
 
         # Divide the weight matrix along the last dimension.
@@ -382,7 +384,12 @@ class VocabParallelClassifier1D(ParallelLayer):
         # Set up backprop all-reduce.
         input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
         # Matrix multiply.
-        output = F.linear(input_parallel, self.weight, self.bias)
+        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        else:
+            output = output_parallel
         return output
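
Note (not part of the patch): the sketch below illustrates on a single process what the new `gather_output` flag changes. With `gather_output=False`, each tensor-parallel rank keeps only its shard of the logits; with `gather_output=True`, the shards are gathered along the class dimension so every rank sees the full `[batch, num_classes]` output (in the patch this is done by `gather_forward_split_backward`, which concatenates in the forward pass and splits gradients in the backward pass). Everything here is plain PyTorch run locally; `world_size`, `shards`, `shard_logits`, and `full_logits` are illustrative names, not part of the ColossalAI API.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    batch, in_features, num_classes, world_size = 4, 8, 6, 2
    classes_per_rank = num_classes // world_size  # class dim must divide evenly

    x = torch.randn(batch, in_features)
    # Full classifier weight, split along the class (output) dimension,
    # mirroring how VocabParallelClassifier1D partitions its weight.
    full_weight = torch.randn(num_classes, in_features)
    shards = full_weight.chunk(world_size, dim=0)

    # gather_output=False: each "rank" holds only its shard of the logits.
    shard_logits = [F.linear(x, w) for w in shards]
    print([tuple(t.shape) for t in shard_logits])  # [(4, 3), (4, 3)]

    # gather_output=True: gather along the last dim, so every rank ends up
    # with the complete [batch, num_classes] logits.
    full_logits = torch.cat(shard_logits, dim=-1)
    print(tuple(full_logits.shape))  # (4, 6)

    # Sanity check: gathered shard outputs match the unpartitioned linear.
    assert torch.allclose(full_logits, F.linear(x, full_weight), atol=1e-6)

Keeping `gather_output=False` as the default preserves the layer's previous behavior, which matters when a vocab-parallel loss consumes the sharded logits directly; the flag is only needed when downstream code expects the full class dimension on every rank.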