from typing import Dict, List

import torch.distributed as dist
import torch.nn as nn

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import is_using_ddp


def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    """Group the parameters of ``model`` by their expert parallel size.

    A parameter carrying a ``moe_info`` attribute is filed under
    ``param.moe_info.ep_size``. Every other parameter is treated as a
    data-parallel parameter (replicated on each GPU) and filed under key ``1``.

    Args:
        model (:class:`torch.nn.Module`): The PyTorch module whose parameters
            are grouped.

    Returns:
        A dict mapping each expert parallel size to the list of parameters
        that have that size, in ``model.parameters()`` order.
    """
    grouped: Dict[int, List[nn.Parameter]] = {}
    for param in model.parameters():
        # Parameters without MoE metadata are data-parallel: ep_size is 1.
        ep_size = param.moe_info.ep_size if hasattr(param, "moe_info") else 1
        grouped.setdefault(ep_size, []).append(param)
    return grouped


def sync_moe_model_param(model: nn.Module):
    """Make the parameters of ``model`` consistent in the MoE parallel context.

    Does nothing unless ``is_using_ddp()`` is true. Otherwise the parameters
    are grouped by expert parallel size and broadcast in place:

    * parameters with ep_size 1 are broadcast over the ``ParallelMode.DATA``
      group from the first rank listed for that group;
    * parameters with any other ep_size, except ``MOE_CONTEXT.world_size``,
      are broadcast over their own ``moe_info.dp_group``.

    Args:
        model (:class:`torch.nn.Module`): The PyTorch model whose parameters
            are synchronized.
    """
    if not is_using_ddp():
        return

    params_by_ep_size = get_moe_epsize_param_dict(model)

    # Synchronize the parameters whose dp_group is the whole world.
    if 1 in params_by_ep_size:
        root_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
        for replicated_param in params_by_ep_size[1]:
            dist.broadcast(
                replicated_param,
                src=root_rank,
                group=gpc.get_group(ParallelMode.DATA),
            )

    for ep_size, expert_params in params_by_ep_size.items():
        # ep_size 1 was handled above; when ep_size equals the world size,
        # communication is not needed.
        if ep_size == 1 or ep_size == MOE_CONTEXT.world_size:
            continue

        # NOTE(review): dist.get_rank(group) yields this process's rank inside
        # the expert-parallel group, and it is then passed as the broadcast
        # source for the data-parallel group. Presumably the MoE group layout
        # makes that value a valid source rank for dp_group - confirm.
        source_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
        for expert_param in expert_params:
            dist.broadcast(
                expert_param,
                src=source_rank,
                group=expert_param.moe_info.dp_group,
            )