|
|
|
@ -176,7 +176,6 @@ class Classifier1D(ParallelLayer):
|
|
|
|
|
set_parallel_input(False) |
|
|
|
|
env.vocab_parallel = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.num_classes |
|
|
|
|
if self.has_weight: |
|
|
|
@ -450,7 +449,6 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
|
is_parallel_output = not self.gather_output |
|
|
|
|
set_parallel_input(is_parallel_output) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features |
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) |
|
|
|
@ -589,7 +587,6 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
|
self._set_tensor_parallel_attributes() |
|
|
|
|
set_parallel_input(False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None: |
|
|
|
|
fan_in, fan_out = self.in_features, self.out_features |
|
|
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) |
|
|
|
|