mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] add stage manager (#4093)
* [pipeline] add stage manager * [test] add pipeline stage manager test * [pipeline] add docstring for stage managerpull/4445/head
parent
5e1a9d48dd
commit
422544222f
|
@ -0,0 +1,176 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStageManager:
|
||||||
|
"""PipelineStageManager is a helper class to manage pipeline stages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pg_mesh (ProcessGroupMesh): Process group mesh.
|
||||||
|
pipeline_axis (int): The axis along which the pipeline is constructed.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
num_stages (int): Number of stages in the pipeline.
|
||||||
|
stage (int): The current stage.
|
||||||
|
num_virtual_stages (int): Number of virtual stages in the pipeline.
|
||||||
|
virtual_stage (int): The current virtual stage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
|
||||||
|
self.pg_mesh = pg_mesh
|
||||||
|
self.pipeline_axis = pipeline_axis
|
||||||
|
self.num_virtual_stages: Optional[int] = None
|
||||||
|
self.virtual_stage: Optional[int] = None
|
||||||
|
self.prev_rank: Optional[Tuple[int, ...]] = None
|
||||||
|
self.next_rank: Optional[Tuple[int, ...]] = None
|
||||||
|
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
|
||||||
|
# init prev and next coord
|
||||||
|
coord = self.pg_mesh.coordinate()
|
||||||
|
if self.stage > 0:
|
||||||
|
prev_coord = coord[: self.pipeline_axis] + \
|
||||||
|
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
|
||||||
|
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape)
|
||||||
|
if self.stage < self.num_stages - 1:
|
||||||
|
next_coord = coord[: self.pipeline_axis] + \
|
||||||
|
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
|
||||||
|
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape)
|
||||||
|
|
||||||
|
# init p2p process groups
|
||||||
|
stages = list(range(self.num_stages))
|
||||||
|
for prev, cur in zip(stages[:-1], stages[1:]):
|
||||||
|
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
|
||||||
|
if self.stage in [prev, cur]:
|
||||||
|
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
|
||||||
|
self.p2p_groups[tuple(ranks_in_group)] = group
|
||||||
|
|
||||||
|
def is_first_stage(self, virtual: bool = False) -> bool:
|
||||||
|
"""Is the current stage the first stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether the current stage is the first stage.
|
||||||
|
"""
|
||||||
|
if virtual:
|
||||||
|
assert self.num_virtual_stages is not None
|
||||||
|
return self.virtual_stage == 0
|
||||||
|
return self.stage == 0
|
||||||
|
|
||||||
|
def is_last_stage(self, virtual: bool = False) -> bool:
|
||||||
|
"""Is the current stage the last stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether the current stage is the last stage.
|
||||||
|
"""
|
||||||
|
if virtual:
|
||||||
|
assert self.num_virtual_stages is not None
|
||||||
|
return self.virtual_stage == self.num_virtual_stages - 1
|
||||||
|
return self.stage == self.num_stages - 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_stages(self) -> int:
|
||||||
|
"""Number of stages in the pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of stages in the pipeline.
|
||||||
|
"""
|
||||||
|
return self.pg_mesh.size(self.pipeline_axis)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stage(self) -> int:
|
||||||
|
"""Current stage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Current stage.
|
||||||
|
"""
|
||||||
|
return self.pg_mesh.coordinate(self.pipeline_axis)
|
||||||
|
|
||||||
|
def get_rank(self) -> int:
|
||||||
|
"""Get the rank of the current process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Rank of the current process.
|
||||||
|
"""
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
def get_prev_rank(self) -> int:
|
||||||
|
"""Get the rank of the previous stage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Rank of the previous stage.
|
||||||
|
"""
|
||||||
|
assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
|
||||||
|
return self.prev_rank
|
||||||
|
|
||||||
|
def get_next_rank(self) -> int:
|
||||||
|
"""Get the rank of the next stage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Rank of the next stage.
|
||||||
|
"""
|
||||||
|
assert not self.is_last_stage(), "Cannot get next rank in the last stage."
|
||||||
|
return self.next_rank
|
||||||
|
|
||||||
|
def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
|
||||||
|
"""Set the number of virtual stages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_virtual_stages (int): Number of virtual stages.
|
||||||
|
"""
|
||||||
|
self.num_virtual_stages = num_virtual_stages
|
||||||
|
|
||||||
|
def set_virtual_stage(self, virtual_stage: int) -> None:
|
||||||
|
"""Set the virtual stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
virtual_stage (int): Virtual stage.
|
||||||
|
"""
|
||||||
|
self.virtual_stage = virtual_stage
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def switch_virtual_stage(self, virtual_stage: int) -> None:
|
||||||
|
"""A context manager to switch virtual stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
virtual_stage (int): Target virtual stage.
|
||||||
|
"""
|
||||||
|
old_stage = self.virtual_stage
|
||||||
|
try:
|
||||||
|
self.set_virtual_stage(virtual_stage)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.set_virtual_stage(old_stage)
|
||||||
|
|
||||||
|
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
|
||||||
|
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
first_rank (int): The first rank.
|
||||||
|
second_rank (int): The second rank.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProcessGroup: P2P process group between the two ranks.
|
||||||
|
"""
|
||||||
|
if first_rank > second_rank:
|
||||||
|
first_rank, second_rank = second_rank, first_rank
|
||||||
|
return self.p2p_groups[(first_rank, second_rank)]
|
||||||
|
|
||||||
|
def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
|
||||||
|
"""Get the process group of the given stages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stages (List[int]): List of stages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProcessGroup: Process group of the given stages.
|
||||||
|
"""
|
||||||
|
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
|
|
@ -0,0 +1,86 @@
|
||||||
|
import pytest
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.testing import spawn
|
||||||
|
|
||||||
|
|
||||||
|
def check_stage_manager():
|
||||||
|
DP_DIM, PP_DIM = 0, 1
|
||||||
|
DP_SIZE, PP_SIZE = 2, 2
|
||||||
|
RANK_TO_COORDINATE = {
|
||||||
|
0: (0, 0),
|
||||||
|
1: (0, 1),
|
||||||
|
2: (1, 0),
|
||||||
|
3: (1, 1),
|
||||||
|
}
|
||||||
|
PP_RANKS_IN_GROUP = {
|
||||||
|
0: [0, 1],
|
||||||
|
1: [0, 1],
|
||||||
|
2: [2, 3],
|
||||||
|
3: [2, 3],
|
||||||
|
}
|
||||||
|
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||||
|
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||||
|
rank = dist.get_rank()
|
||||||
|
|
||||||
|
# check stage info
|
||||||
|
assert stage_manager.num_stages == PP_SIZE
|
||||||
|
assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM]
|
||||||
|
|
||||||
|
# check is_first_stage
|
||||||
|
ranks_in_group = PP_RANKS_IN_GROUP[rank]
|
||||||
|
is_first_stage = ranks_in_group.index(rank) == 0
|
||||||
|
assert stage_manager.is_first_stage() == is_first_stage
|
||||||
|
|
||||||
|
# check is_last_stage
|
||||||
|
is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1
|
||||||
|
assert stage_manager.is_last_stage() == is_last_stage
|
||||||
|
|
||||||
|
# check prev rank
|
||||||
|
if not is_first_stage:
|
||||||
|
prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1]
|
||||||
|
assert stage_manager.get_prev_rank() == prev_rank
|
||||||
|
|
||||||
|
# check next rank
|
||||||
|
if not is_last_stage:
|
||||||
|
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
|
||||||
|
assert stage_manager.get_next_rank() == next_rank
|
||||||
|
|
||||||
|
# check virtual stage
|
||||||
|
stage_manager.set_num_virtual_stages(PP_SIZE * 2)
|
||||||
|
assert stage_manager.num_virtual_stages == PP_SIZE * 2
|
||||||
|
stage_manager.set_virtual_stage(stage_manager.stage * 2)
|
||||||
|
assert stage_manager.virtual_stage == stage_manager.stage * 2
|
||||||
|
with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
|
||||||
|
assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
|
||||||
|
assert stage_manager.virtual_stage == stage_manager.stage * 2
|
||||||
|
|
||||||
|
# check p2p groups
|
||||||
|
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
|
||||||
|
if rank in [prev, cur]:
|
||||||
|
group = stage_manager.get_p2p_process_group(prev, cur)
|
||||||
|
dist.barrier(group=group)
|
||||||
|
|
||||||
|
# check stage groups
|
||||||
|
pg_mesh = ProcessGroupMesh(4)
|
||||||
|
stage_manager = PipelineStageManager(pg_mesh, 0)
|
||||||
|
group = stage_manager.init_process_group_by_stages([0, 2])
|
||||||
|
if rank in [0, 2]:
|
||||||
|
dist.barrier(group=group)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
check_stage_manager()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
def test_process_group_mesh():
|
||||||
|
spawn(run_dist, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_process_group_mesh()
|
Loading…
Reference in New Issue