|
|
|
@ -283,11 +283,13 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|
|
|
|
weight: Parameter = None,
|
|
|
|
|
bias: bool = True,
|
|
|
|
|
dtype: torch.dtype = None,
|
|
|
|
|
gather_output: bool = False,
|
|
|
|
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
|
|
|
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.in_features = in_features
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.gather_output = gather_output
|
|
|
|
|
self.parallel_input = get_parallel_input()
|
|
|
|
|
|
|
|
|
|
# Divide the weight matrix along the last dimension.
|
|
|
|
@ -382,7 +384,12 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
|
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
|
|
|
|
# Matrix multiply.
|
|
|
|
|
output = F.linear(input_parallel, self.weight, self.bias)
|
|
|
|
|
output_parallel = F.linear(input_parallel, self.weight, self.bias)
|
|
|
|
|
if self.gather_output:
|
|
|
|
|
# All-gather across the partitions.
|
|
|
|
|
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
output = output_parallel
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|