mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style (#1303)
parent
7696cead8d
commit
f83c4d6597
|
@ -6,6 +6,7 @@ from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
|
|
||||||
class PipelineSharedModuleWrapper:
|
class PipelineSharedModuleWrapper:
|
||||||
|
|
||||||
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
|
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
|
||||||
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
|
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
|
||||||
self.pipeline_ranks = pipeline_ranks
|
self.pipeline_ranks = pipeline_ranks
|
||||||
|
@ -22,10 +23,7 @@ class PipelineSharedModuleWrapper:
|
||||||
num_pp_stages = num_dp_groups // pp_size
|
num_pp_stages = num_dp_groups // pp_size
|
||||||
for i in range(dp_size):
|
for i in range(dp_size):
|
||||||
for j in range(num_pp_stages):
|
for j in range(num_pp_stages):
|
||||||
pipeline_ranks = list(
|
pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
|
||||||
range(i * num_dp_groups + j,
|
|
||||||
(i + 1) * num_dp_groups,
|
|
||||||
num_pp_stages))
|
|
||||||
sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
|
sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
|
||||||
group = dist.new_group(sub_ranks)
|
group = dist.new_group(sub_ranks)
|
||||||
if rank in sub_ranks:
|
if rank in sub_ranks:
|
||||||
|
|
Loading…
Reference in New Issue