from typing import List, Tuple, Union

import torch.distributed as dist
import torch.nn as nn

from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc


class PipelineSharedModuleWrapper:
    """Tie parameters together across a chosen subset of pipeline stages.

    On construction a process group is created for the global ranks that sit
    at the stage positions ``pipeline_ranks`` of every pipeline. The group
    that contains the calling rank (if any) is kept on the instance.
    ``register_module`` / ``register_parameter`` then tag parameters with that
    group (attribute ``pipeline_shared_module_pg``) and overwrite their values
    with those held by the first listed stage.

    Args:
        pipeline_ranks: Stage indices (positions inside one pipeline) whose
            ranks should share the registered parameters. Must contain more
            than one entry.

    Attributes:
        pipeline_ranks: The stage indices passed to the constructor.
        group: Process group joining the sharing ranks, or ``None`` when the
            calling rank does not belong to any of them.
        ranks_in_group: Global ranks of ``group`` in the same order as
            ``pipeline_ranks``, or ``None`` when the calling rank is excluded.
    """

    def __init__(self, pipeline_ranks: Union[List[int], Tuple[int, ...]]) -> None:
        assert len(pipeline_ranks) > 1, f"Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}"
        self.pipeline_ranks = pipeline_ranks
        self.group = None
        self.ranks_in_group = None
        self._init_group()

    def _init_group(self) -> None:
        """Create one process group per pipeline and remember the caller's own.

        Every process runs the full double loop and calls ``dist.new_group``
        for every sub-group, including those it is not a member of, so that
        all processes issue the same sequence of group creations.
        """
        world_size = gpc.get_world_size(ParallelMode.GLOBAL)
        dp_size = gpc.get_world_size(ParallelMode.DATA)
        pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
        rank = gpc.get_global_rank()

        # Number of global ranks that make up one data-parallel replica.
        ranks_per_replica = world_size // dp_size
        # Distance (in global rank) between consecutive stages of one pipeline.
        # NOTE(review): presumably this equals the tensor-parallel size — confirm.
        stage_stride = ranks_per_replica // pp_size

        for replica_idx in range(dp_size):
            replica_start = replica_idx * ranks_per_replica
            replica_end = replica_start + ranks_per_replica
            for offset in range(stage_stride):
                # Global ranks of every stage of this particular pipeline.
                stage_ranks = list(range(replica_start + offset, replica_end, stage_stride))
                # Keep only the stages that were asked to share parameters.
                sub_ranks = [stage_ranks[idx] for idx in self.pipeline_ranks]
                group = dist.new_group(sub_ranks)
                if rank in sub_ranks:
                    self.group = group
                    self.ranks_in_group = sub_ranks

    def _source_rank(self) -> int:
        """Return the global rank whose parameter values are broadcast.

        The source is the rank of the first stage listed in
        ``pipeline_ranks``. ``ranks_in_group`` is ordered like
        ``pipeline_ranks``, so that rank is its first element; indexing it
        with the stage index itself would be out of range or pick the wrong
        rank whenever the first stage index is not 0.

        Raises:
            AssertionError: If the calling rank is not part of the group.
        """
        assert (
            self.ranks_in_group is not None
        ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}"
        return self.ranks_in_group[0]

    def register_module(self, module: nn.Module) -> None:
        """Mark every parameter of ``module`` as shared and synchronise it.

        Each parameter gets the attribute ``pipeline_shared_module_pg`` set to
        the sharing process group and is then broadcast within that group
        from the first listed stage.

        Args:
            module: Module whose parameters are shared across the stages.

        Raises:
            AssertionError: If the calling rank is not part of the group.
        """
        src = self._source_rank()
        for p in module.parameters():
            setattr(p, "pipeline_shared_module_pg", self.group)
            dist.broadcast(p, src, group=self.group)

    def register_parameter(self, param: nn.Parameter) -> None:
        """Mark a single parameter as shared and synchronise it.

        The parameter gets the attribute ``pipeline_shared_module_pg`` set to
        the sharing process group and is then broadcast within that group
        from the first listed stage.

        Args:
            param: Parameter shared across the stages.

        Raises:
            AssertionError: If the calling rank is not part of the group.
        """
        src = self._source_rank()
        setattr(param, "pipeline_shared_module_pg", self.group)
        dist.broadcast(param, src, group=self.group)