From f83c4d6597008dfebe5a9e48348490cee676c757 Mon Sep 17 00:00:00 2001 From: runluo <68489000+run-qiao@users.noreply.github.com> Date: Wed, 13 Jul 2022 19:01:07 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style (#1303) --- colossalai/nn/layer/wrapper/pipeline_wrapper.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py index 20813dc89..ef1d794cc 100644 --- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -6,6 +6,7 @@ from colossalai.core import global_context as gpc class PipelineSharedModuleWrapper: + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' self.pipeline_ranks = pipeline_ranks @@ -22,10 +23,7 @@ class PipelineSharedModuleWrapper: num_pp_stages = num_dp_groups // pp_size for i in range(dp_size): for j in range(num_pp_stages): - pipeline_ranks = list( - range(i * num_dp_groups + j, - (i + 1) * num_dp_groups, - num_pp_stages)) + pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] group = dist.new_group(sub_ranks) if rank in sub_ranks: