|
|
|
@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
|
|
|
|
|
set_parallel_input(False)
|
|
|
|
|
env.vocab_parallel = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
|
|
|
|
fan_in, fan_out = self.in_features, self.num_classes
|
|
|
|
|
if self.has_weight:
|
|
|
|
@ -155,8 +156,14 @@ class Classifier1D(ParallelLayer):
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
|
if self.parallel_input:
|
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1])
|
|
|
|
|
input_ = input_
|
|
|
|
|
else:
|
|
|
|
|
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
|
|
|
|
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
|
|
|
|
|
output_parallel = F.linear(input_, self.weight)
|
|
|
|
@ -235,6 +242,9 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1])
|
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
|
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
|
|
|
|
# Matrix multiply.
|
|
|
|
@ -305,6 +315,7 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
is_parallel_output = not self.gather_output
|
|
|
|
|
set_parallel_input(is_parallel_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features
|
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
|
|
|
@ -318,6 +329,9 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1])
|
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
|
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
|
|
|
|
# Matrix multiply.
|
|
|
|
@ -329,6 +343,7 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
output = output_parallel
|
|
|
|
|
|
|
|
|
|
if self.skip_bias_add:
|
|
|
|
|
return output, self.bias
|
|
|
|
|
else:
|
|
|
|
@ -393,6 +408,7 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
|
self._set_tensor_parallel_attributes()
|
|
|
|
|
set_parallel_input(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features
|
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
|
|
|
@ -407,8 +423,14 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
|
if self.parallel_input:
|
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1])
|
|
|
|
|
input_ = input_
|
|
|
|
|
else:
|
|
|
|
|
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
|
|
|
|
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
|
|
|
|
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
|
|
|
|
|
output_parallel = F.linear(input_, self.weight)
|
|
|
|
|