mirror of https://github.com/hpcaitech/ColossalAI
[TP] add assert for tp1d (#621)
parent
369a288bf3
commit
1c40ee8749
|
@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
|
|||
set_parallel_input(False)
|
||||
env.vocab_parallel = False
|
||||
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
if self.has_weight:
|
||||
|
@ -155,8 +156,14 @@ class Classifier1D(ParallelLayer):
|
|||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
|
@ -235,6 +242,9 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
# Matrix multiply.
|
||||
|
@ -305,6 +315,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
is_parallel_output = not self.gather_output
|
||||
set_parallel_input(is_parallel_output)
|
||||
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
@ -318,6 +329,9 @@ class Linear1D_Col(ParallelLayer):
|
|||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
# Matrix multiply.
|
||||
|
@ -329,6 +343,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
|
@ -393,6 +408,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
@ -407,8 +423,14 @@ class Linear1D_Row(ParallelLayer):
|
|||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
|
|
Loading…
Reference in New Issue