mirror of https://github.com/hpcaitech/ColossalAI
add gather_output for VocabParallelClassifier1D (#1569)
parent e615cfc3a8
commit 10dd8226b1
|
@@ -283,11 +283,13 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
self.gather_output = gather_output
|
||||
self.parallel_input = get_parallel_input()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
|
@@ -382,7 +384,12 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
# Matrix multiply.
|
||||
- output = F.linear(input_parallel, self.weight, self.bias)
|
||||
+ output_parallel = F.linear(input_parallel, self.weight, self.bias)
|
||||
+ if self.gather_output:
|
||||
+ # All-gather across the partitions.
|
||||
+ output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
+ else:
|
||||
+ output = output_parallel
|
||||
return output
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue