From 95c35f73bd4ab65e89ee3f3e3dc2936cff01bfc7 Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 23 Sep 2022 17:20:41 +0800 Subject: [PATCH] [moe] initialize MoE groups by ProcessGroup (#1640) --- colossalai/context/moe_context.py | 39 ++--------------- tests/test_moe/test_moe_colo_init.py | 63 ++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 35 deletions(-) create mode 100644 tests/test_moe/test_moe_colo_init.py diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index 23eec6186..b36b88455 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -3,6 +3,7 @@ import torch.distributed as dist from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor import ProcessGroup from typing import Tuple @@ -22,41 +23,9 @@ class MoeParallelInfo: _check_sanity() self.ep_size = ep_size self.dp_size = dp_size - self.ep_group = None - # data parallel group for experts, since ep_group is different - # we may have different dp_group from get_group(ParallelMode.DATA) - self.dp_group = None - - # Here we assume tensor parallel size = 1 - # Otherwise, MoE can't be used - # Since TENSOR parallel group and DATA parallel group - # have been created, we can use them directly. - if ep_size == 1: - from colossalai.core import global_context as gpc - self.ep_group = gpc.get_group(ParallelMode.TENSOR) - self.dp_group = gpc.get_group(ParallelMode.DATA) - return - - if dp_size == 1: - from colossalai.core import global_context as gpc - self.ep_group = gpc.get_group(ParallelMode.DATA) - self.dp_group = gpc.get_group(ParallelMode.TENSOR) - return - - rank = dist.get_rank() - # Create expert parallel group - for i in range(dp_size): - ranks = [i * ep_size + j for j in range(ep_size)] - group = dist.new_group(ranks) - if rank in ranks: - self.ep_group = group - - # Create data parallel group - for j in range(ep_size): - ranks = [i * ep_size + j for i in range(dp_size)] - group = dist.new_group(ranks) - if rank in ranks: - self.dp_group = group + self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) + self.ep_group = self.pg.tp_process_group() + self.dp_group = self.pg.dp_process_group() class MoeContext(metaclass=SingletonMeta): diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py new file mode 100644 index 000000000..ae0c1390c --- /dev/null +++ b/tests/test_moe/test_moe_colo_init.py @@ -0,0 +1,63 @@ +from functools import partial + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import parameterize +from colossalai.utils import free_port +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.utils.model.colo_init_context import ColoInitContext + +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import get_current_device + +from tests.test_zero.common import CONFIG +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4)