|
|
|
@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
|
|
|
|
|
set_parallel_input(False) |
|
|
|
|
env.vocab_parallel = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.num_classes |
|
|
|
|
if self.has_weight: |
|
|
|
@ -155,8 +156,14 @@ class Classifier1D(ParallelLayer):
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor: |
|
|
|
|
# Set up backprop all-reduce. |
|
|
|
|
if self.parallel_input: |
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1]) |
|
|
|
|
input_ = input_ |
|
|
|
|
else: |
|
|
|
|
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) |
|
|
|
|
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) |
|
|
|
|
|
|
|
|
|
output_parallel = F.linear(input_, self.weight) |
|
|
|
@ -235,6 +242,9 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition) |
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor: |
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1]) |
|
|
|
|
# Set up backprop all-reduce. |
|
|
|
|
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) |
|
|
|
|
# Matrix multiply. |
|
|
|
@ -305,6 +315,7 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
is_parallel_output = not self.gather_output |
|
|
|
|
set_parallel_input(is_parallel_output) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features |
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) |
|
|
|
@ -318,6 +329,9 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition) |
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: |
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1]) |
|
|
|
|
# Set up backprop all-reduce. |
|
|
|
|
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) |
|
|
|
|
# Matrix multiply. |
|
|
|
@ -329,6 +343,7 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) |
|
|
|
|
else: |
|
|
|
|
output = output_parallel |
|
|
|
|
|
|
|
|
|
if self.skip_bias_add: |
|
|
|
|
return output, self.bias |
|
|
|
|
else: |
|
|
|
@ -393,6 +408,7 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
|
self._set_tensor_parallel_attributes() |
|
|
|
|
set_parallel_input(False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features |
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) |
|
|
|
@ -407,8 +423,14 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor: |
|
|
|
|
# Set up backprop all-reduce. |
|
|
|
|
if self.parallel_input: |
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1]) |
|
|
|
|
input_ = input_ |
|
|
|
|
else: |
|
|
|
|
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ |
|
|
|
|
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( |
|
|
|
|
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) |
|
|
|
|
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) |
|
|
|
|
|
|
|
|
|
output_parallel = F.linear(input_, self.weight) |
|
|
|
|