[moe] initialize MoE groups by ProcessGroup (#1640)

pull/1642/head
HELSON 2022-09-23 17:20:41 +08:00 committed by GitHub
parent e57df80325
commit 95c35f73bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 35 deletions

View File

@ -3,6 +3,7 @@ import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor import ProcessGroup
from typing import Tuple
@ -22,41 +23,9 @@ class MoeParallelInfo:
_check_sanity()
self.ep_size = ep_size
self.dp_size = dp_size
self.ep_group = None
# data parallel group for experts, since ep_group is different
# we may have different dp_group from get_group(ParallelMode.DATA)
self.dp_group = None
# Here we assume tensor parallel size = 1
# Otherwise, MoE can't be used
# Since TENSOR parallel group and DATA parallel group
# have been created, we can use them directly.
if ep_size == 1:
from colossalai.core import global_context as gpc
self.ep_group = gpc.get_group(ParallelMode.TENSOR)
self.dp_group = gpc.get_group(ParallelMode.DATA)
return
if dp_size == 1:
from colossalai.core import global_context as gpc
self.ep_group = gpc.get_group(ParallelMode.DATA)
self.dp_group = gpc.get_group(ParallelMode.TENSOR)
return
rank = dist.get_rank()
# Create expert parallel group
for i in range(dp_size):
ranks = [i * ep_size + j for j in range(ep_size)]
group = dist.new_group(ranks)
if rank in ranks:
self.ep_group = group
# Create data parallel group
for j in range(ep_size):
ranks = [i * ep_size + j for i in range(dp_size)]
group = dist.new_group(ranks)
if rank in ranks:
self.dp_group = group
self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
self.ep_group = self.pg.tp_process_group()
self.dp_group = self.pg.dp_process_group()
class MoeContext(metaclass=SingletonMeta):

View File

@ -0,0 +1,63 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.context import MOE_CONTEXT
from colossalai.tensor import ColoParameter
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import get_current_device
from tests.test_zero.common import CONFIG
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
@parameterize("init_device_type", ['cpu', 'cuda'])
def exam_moe_colo_init(init_device_type):
world_size = dist.get_world_size()
if init_device_type == 'cuda':
init_device = get_current_device()
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:
raise NotImplementedError("Unknown device found.")
with ColoInitContext(device=init_device):
model = MoeModel(checkpoint=True)
for name, param in model.named_parameters():
assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)
if hasattr(param, "moe_info"):
param.set_process_group(param.moe_info.pg)
if hasattr(param, "moe_info"):
assert param.process_group.dp_world_size() == param.moe_info.dp_size
else:
assert param.process_group.dp_world_size() == world_size
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
exam_moe_colo_init()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_colo_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_colo_init(world_size=4)