|
|
|
@ -4,6 +4,7 @@
|
|
|
|
|
import math |
|
|
|
|
|
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
|
from colossalai.global_variables import tensor_parallel_env as env |
|
|
|
|
from colossalai.registry import DIST_GROUP_INITIALIZER |
|
|
|
|
|
|
|
|
@ -213,7 +214,8 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer):
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for k in range(self.depth): |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) |
|
|
|
|
h * self.depth**3 + i + self.depth * (j + self.depth * k) |
|
|
|
|
for j in range(self.depth) |
|
|
|
|
for i in range(self.depth) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
@ -266,7 +268,8 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for j in range(self.depth): |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) |
|
|
|
|
h * self.depth**3 + i + self.depth * (j + self.depth * k) |
|
|
|
|
for k in range(self.depth) |
|
|
|
|
for i in range(self.depth) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|