[NFC] polish initializer_3d.py code style (#3279)

pull/3313/head
Kai Wang (Victor Kai) 2 years ago committed by binmakeswell
parent 94eec1c5ad
commit 964a28678f

@ -4,6 +4,7 @@
import math import math
import torch.distributed as dist import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
@ -213,7 +214,8 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer):
for h in range(self.num_group): for h in range(self.num_group):
for k in range(self.depth): for k in range(self.depth):
ranks = [ ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) h * self.depth**3 + i + self.depth * (j + self.depth * k)
for j in range(self.depth)
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
@ -266,7 +268,8 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
for h in range(self.num_group): for h in range(self.num_group):
for j in range(self.depth): for j in range(self.depth):
ranks = [ ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) h * self.depth**3 + i + self.depth * (j + self.depth * k)
for k in range(self.depth)
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)

Loading…
Cancel
Save