[]Corrected 3d vocab parallel embedding (#707)

pull/710/head
アマデウス 2022-04-11 10:17:55 +08:00 committed by GitHub
parent ee112fe1da
commit 3fc8a204dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -525,7 +525,7 @@ class VocabParallelClassifier3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
@ -1048,7 +1048,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):