[NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style (#1303)

pull/1305/head
runluo 2022-07-13 19:01:07 +08:00 committed by GitHub
parent 7696cead8d
commit f83c4d6597
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 4 deletions

View File

@ -6,6 +6,7 @@ from colossalai.core import global_context as gpc
class PipelineSharedModuleWrapper:
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
self.pipeline_ranks = pipeline_ranks
@ -22,10 +23,7 @@ class PipelineSharedModuleWrapper:
num_pp_stages = num_dp_groups // pp_size
for i in range(dp_size):
for j in range(num_pp_stages):
pipeline_ranks = list(
range(i * num_dp_groups + j,
(i + 1) * num_dp_groups,
num_pp_stages))
pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
group = dist.new_group(sub_ranks)
if rank in sub_ranks: