[TP] add assert for tp1d (#621)

pull/592/head^2
Ziyue Jiang 3 years ago committed by GitHub
parent 369a288bf3
commit 1c40ee8749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
set_parallel_input(False)
env.vocab_parallel = False
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
@ -155,8 +156,14 @@ class Classifier1D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
@ -235,6 +242,9 @@ class VocabParallelClassifier1D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tensor:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
@ -305,6 +315,7 @@ class Linear1D_Col(ParallelLayer):
is_parallel_output = not self.gather_output
set_parallel_input(is_parallel_output)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@ -318,6 +329,9 @@ class Linear1D_Col(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
@ -329,6 +343,7 @@ class Linear1D_Col(ParallelLayer):
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
@ -393,6 +408,7 @@ class Linear1D_Row(ParallelLayer):
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@ -407,8 +423,14 @@ class Linear1D_Row(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)

Loading…
Cancel
Save