diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index edfe07697..141d988f6 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
         set_parallel_input(False)
         env.vocab_parallel = False
 
+
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.num_classes
         if self.has_weight:
@@ -155,8 +156,14 @@ class Classifier1D(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
+            assert input_.shape[-1] == self.weight.shape[-1], \
+                'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1])
             input_ = input_
         else:
+            assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
+                'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
             input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
 
         output_parallel = F.linear(input_, self.weight)
@@ -235,6 +242,9 @@ class VocabParallelClassifier1D(ParallelLayer):
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
     def forward(self, input_: Tensor) -> Tensor:
+        assert input_.shape[-1] == self.weight.shape[-1], \
+            'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                input_.shape, self.weight.shape, self.weight.shape[-1])
         # Set up backprop all-reduce.
         input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
         # Matrix multiply.
@@ -305,6 +315,7 @@ class Linear1D_Col(ParallelLayer):
         is_parallel_output = not self.gather_output
         set_parallel_input(is_parallel_output)
 
+
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@@ -318,6 +329,9 @@ class Linear1D_Col(ParallelLayer):
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
+        assert input_.shape[-1] == self.weight.shape[-1], \
+            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                input_.shape, self.weight.shape, self.weight.shape[-1])
         # Set up backprop all-reduce.
         input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
         # Matrix multiply.
@@ -329,6 +343,7 @@ class Linear1D_Col(ParallelLayer):
             output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
         else:
             output = output_parallel
+
         if self.skip_bias_add:
             return output, self.bias
         else:
@@ -393,6 +408,7 @@ class Linear1D_Row(ParallelLayer):
         self._set_tensor_parallel_attributes()
         set_parallel_input(False)
 
+
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@@ -407,8 +423,14 @@ class Linear1D_Row(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
+            assert input_.shape[-1] == self.weight.shape[-1], \
+                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1])
             input_ = input_
         else:
+            assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
+                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
             input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
 
         output_parallel = F.linear(input_, self.weight)
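Every assertion added above enforces the same invariant: the last dimension of the activation must match the input dimension of the local weight shard, either directly (when the input is already split across the 1D tensor-parallel group) or after dividing by the tensor-parallel world size (when a full-size input is about to be split). Tripping an assertion now reports the offending shapes by name instead of surfacing as an opaque matmul error inside F.linear. A minimal sketch of that invariant, with no distributed setup; the tensor_parallel_size and feature sizes below are hypothetical, and plain // stands in for colossalai's divide helper:

    # Sketch of the shape check the new assertions perform.
    # All sizes here are illustrative, not from the patch.
    import torch

    tensor_parallel_size = 4
    in_features, out_features = 1024, 4096

    # Row-parallel linear: each rank holds a weight shard of shape
    # (out_features, in_features / tensor_parallel_size).
    weight_shard = torch.empty(out_features, in_features // tensor_parallel_size)

    # parallel_input=True: input is already split, so last dims must match.
    split_input = torch.randn(8, in_features // tensor_parallel_size)
    assert split_input.shape[-1] == weight_shard.shape[-1]

    # parallel_input=False: a full-size input must divide evenly across ranks
    # before split_forward_gather_backward shards it along the last dim.
    full_input = torch.randn(8, in_features)
    assert full_input.shape[-1] // tensor_parallel_size == weight_shard.shape[-1]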