|
|
|
@ -12,8 +12,7 @@ from ..parallel_mode import ParallelMode
|
|
|
|
|
from .process_group_initializer import ProcessGroupInitializer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_tesseract_env_var(tesseract_dim: int, |
|
|
|
|
tesseract_dep: int): |
|
|
|
|
def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): |
|
|
|
|
# check global variable for TESSERACT |
|
|
|
|
env_tesseract_dim = env.tesseract_dim |
|
|
|
|
env_tesseract_dep = env.tesseract_dep |
|
|
|
@ -42,10 +41,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
|
|
|
|
|
:type tesseract_dep: int |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
tesseract_dim: int, |
|
|
|
|
tesseract_dep: int, |
|
|
|
|
*args): |
|
|
|
|
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): |
|
|
|
|
super(Initializer_2p5D_ROW, self).__init__(*args) |
|
|
|
|
self.num_group = self.world_size // self.tensor_parallel_size |
|
|
|
|
self.tesseract_dep = tesseract_dep |
|
|
|
@ -68,8 +64,10 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for j in range(self.tesseract_dim): |
|
|
|
|
for k in range(self.tesseract_dep): |
|
|
|
|
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
|
|
|
|
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)] |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) |
|
|
|
|
for i in range(self.tesseract_dim) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
|
|
|
|
|
if self.rank in ranks: |
|
|
|
@ -92,10 +90,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
|
|
|
|
|
:type tesseract_dep: int |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
tesseract_dim: int, |
|
|
|
|
tesseract_dep: int, |
|
|
|
|
*args): |
|
|
|
|
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): |
|
|
|
|
super(Initializer_2p5D_Col, self).__init__(*args) |
|
|
|
|
self.num_group = self.world_size // self.tensor_parallel_size |
|
|
|
|
self.tesseract_dep = tesseract_dep |
|
|
|
@ -118,8 +113,10 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for i in range(self.tesseract_dim): |
|
|
|
|
for k in range(self.tesseract_dep): |
|
|
|
|
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
|
|
|
|
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)] |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) |
|
|
|
|
for j in range(self.tesseract_dim) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
|
|
|
|
|
if self.rank in ranks: |
|
|
|
@ -142,10 +139,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
|
|
|
|
|
:type tesseract_dep: int |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
tesseract_dim: int, |
|
|
|
|
tesseract_dep: int, |
|
|
|
|
*args): |
|
|
|
|
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): |
|
|
|
|
super(Initializer_2p5D_Dep, self).__init__(*args) |
|
|
|
|
self.num_group = self.world_size // self.tensor_parallel_size |
|
|
|
|
self.tesseract_dep = tesseract_dep |
|
|
|
@ -168,8 +162,10 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for i in range(self.tesseract_dim): |
|
|
|
|
for j in range(self.tesseract_dim): |
|
|
|
|
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
|
|
|
|
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)] |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) |
|
|
|
|
for k in range(self.tesseract_dep) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
|
|
|
|
|
if self.rank in ranks: |
|
|
|
@ -193,10 +189,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
|
|
|
|
|
:type tesseract_dep: int |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
tesseract_dim: int, |
|
|
|
|
tesseract_dep: int, |
|
|
|
|
*args): |
|
|
|
|
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): |
|
|
|
|
super(Initializer_2p5D_XZ, self).__init__(*args) |
|
|
|
|
self.num_group = self.world_size // self.tensor_parallel_size |
|
|
|
|
self.tesseract_dep = tesseract_dep |
|
|
|
@ -218,9 +211,11 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
|
|
|
|
|
|
|
|
|
|
for h in range(self.num_group): |
|
|
|
|
for i in range(self.tesseract_dim): |
|
|
|
|
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( |
|
|
|
|
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in |
|
|
|
|
range(self.tesseract_dim)] |
|
|
|
|
ranks = [ |
|
|
|
|
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) |
|
|
|
|
for k in range(self.tesseract_dep) |
|
|
|
|
for j in range(self.tesseract_dim) |
|
|
|
|
] |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
|
|
|
|
|
if self.rank in ranks: |
|
|
|
@ -253,15 +248,8 @@ class Initializer_2p5D(ProcessGroupInitializer):
|
|
|
|
|
:type depth: int |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
rank: int, |
|
|
|
|
world_size: int, |
|
|
|
|
config: Config, |
|
|
|
|
data_parallel_size: int, |
|
|
|
|
pipeline_parallel_size: int, |
|
|
|
|
tensor_parallel_size: int, |
|
|
|
|
depth: int |
|
|
|
|
): |
|
|
|
|
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, |
|
|
|
|
tensor_parallel_size: int, depth: int): |
|
|
|
|
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) |
|
|
|
|
super().__init__(*args) |
|
|
|
|
self.num_group = self.world_size // self.tensor_parallel_size |
|
|
|
@ -279,10 +267,13 @@ class Initializer_2p5D(ProcessGroupInitializer):
|
|
|
|
|
|
|
|
|
|
def init_dist_group(self): |
|
|
|
|
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. |
|
|
|
|
|
|
|
|
|
:return: Whole 2p5D tensor parallelism's information |
|
|
|
|
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) |
|
|
|
|
""" |
|
|
|
|
parallel_setting = [self.col_initializer.init_dist_group(), self.row_initializer.init_dist_group(), |
|
|
|
|
self.dep_initializer.init_dist_group(), self.xz_initializer.init_dist_group()] |
|
|
|
|
parallel_setting = [ |
|
|
|
|
self.col_initializer.init_dist_group(), |
|
|
|
|
self.row_initializer.init_dist_group(), |
|
|
|
|
self.dep_initializer.init_dist_group(), |
|
|
|
|
self.xz_initializer.init_dist_group() |
|
|
|
|
] |
|
|
|
|
return parallel_setting |
|
|
|
|