add gather_output for VocabParallelClassifier1D (#1569)

pull/1574/head
ver217 2 years ago committed by GitHub
parent e615cfc3a8
commit 10dd8226b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -283,11 +283,13 @@ class VocabParallelClassifier1D(ParallelLayer):
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.gather_output = gather_output
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
@ -382,7 +384,12 @@ class VocabParallelClassifier1D(ParallelLayer):
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
output = F.linear(input_parallel, self.weight, self.bias)
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
return output

Loading…
Cancel
Save