mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style (#1303)
parent 7696cead8d
commit f83c4d6597
@@ -6,6 +6,7 @@ from colossalai.core import global_context as gpc


 class PipelineSharedModuleWrapper:
+
     def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
         assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
         self.pipeline_ranks = pipeline_ranks
@@ -22,10 +23,7 @@ class PipelineSharedModuleWrapper:
         num_pp_stages = num_dp_groups // pp_size
         for i in range(dp_size):
             for j in range(num_pp_stages):
-                pipeline_ranks = list(
-                    range(i * num_dp_groups + j,
-                          (i + 1) * num_dp_groups,
-                          num_pp_stages))
+                pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
                 sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
                 group = dist.new_group(sub_ranks)
                 if rank in sub_ranks:
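For context on what the reformatted statement computes: each iteration of the nested loop collects the global ranks that make up one pipeline (same data-parallel replica i, same offset j), and the wrapper's pipeline_ranks then pick out which stages of that pipeline share a module before a process group is created for them. Below is a minimal standalone sketch of that arithmetic; the sizes and the shared_stage_indices value are illustrative assumptions, not values from the commit.

# A minimal, self-contained sketch of the rank-grouping arithmetic in the hunk
# above. All sizes below are assumed example values, not taken from the commit.
dp_size = 2        # assumed number of data-parallel replicas
pp_size = 4        # assumed number of pipeline stages
num_dp_groups = 8  # assumed ranks per data-parallel replica (derived from the
                   # world size in the surrounding code, which this hunk omits)
num_pp_stages = num_dp_groups // pp_size  # stride between consecutive stages -> 2

# Hypothetical wrapper argument: tie a module between the first and last stage.
shared_stage_indices = [0, pp_size - 1]

for i in range(dp_size):
    for j in range(num_pp_stages):
        # Global ranks of one pipeline: start at the replica offset plus j,
        # then step by the stage stride.
        pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
        # The wrapper's pipeline_ranks are stage indices into this list.
        sub_ranks = [pipeline_ranks[idx] for idx in shared_stage_indices]
        print(f"replica {i}, offset {j}: pipeline={pipeline_ranks}, shared on {sub_ranks}")

# Expected output:
# replica 0, offset 0: pipeline=[0, 2, 4, 6], shared on [0, 6]
# replica 0, offset 1: pipeline=[1, 3, 5, 7], shared on [1, 7]
# replica 1, offset 0: pipeline=[8, 10, 12, 14], shared on [8, 14]
# replica 1, offset 1: pipeline=[9, 11, 13, 15], shared on [9, 15]

Note that torch.distributed.new_group must be called by every rank for every subgroup, which is why the loop in the diff builds all groups and, per the `if rank in sub_ranks:` check, only keeps the one containing the current rank.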