mirror of https://github.com/hpcaitech/ColossalAI
[moe] initialize MoE groups by ProcessGroup (#1640)
parent
e57df80325
commit
95c35f73bd
|
@ -3,6 +3,7 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.context.singleton_meta import SingletonMeta
|
from colossalai.context.singleton_meta import SingletonMeta
|
||||||
|
from colossalai.tensor import ProcessGroup
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
@ -22,41 +23,9 @@ class MoeParallelInfo:
|
||||||
_check_sanity()
|
_check_sanity()
|
||||||
self.ep_size = ep_size
|
self.ep_size = ep_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.ep_group = None
|
self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
|
||||||
# data parallel group for experts, since ep_group is different
|
self.ep_group = self.pg.tp_process_group()
|
||||||
# we may have different dp_group from get_group(ParallelMode.DATA)
|
self.dp_group = self.pg.dp_process_group()
|
||||||
self.dp_group = None
|
|
||||||
|
|
||||||
# Here we assume tensor parallel size = 1
|
|
||||||
# Otherwise, MoE can't be used
|
|
||||||
# Since TENSOR parallel group and DATA parallel group
|
|
||||||
# have been created, we can use them directly.
|
|
||||||
if ep_size == 1:
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
self.ep_group = gpc.get_group(ParallelMode.TENSOR)
|
|
||||||
self.dp_group = gpc.get_group(ParallelMode.DATA)
|
|
||||||
return
|
|
||||||
|
|
||||||
if dp_size == 1:
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
self.ep_group = gpc.get_group(ParallelMode.DATA)
|
|
||||||
self.dp_group = gpc.get_group(ParallelMode.TENSOR)
|
|
||||||
return
|
|
||||||
|
|
||||||
rank = dist.get_rank()
|
|
||||||
# Create expert parallel group
|
|
||||||
for i in range(dp_size):
|
|
||||||
ranks = [i * ep_size + j for j in range(ep_size)]
|
|
||||||
group = dist.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
self.ep_group = group
|
|
||||||
|
|
||||||
# Create data parallel group
|
|
||||||
for j in range(ep_size):
|
|
||||||
ranks = [i * ep_size + j for i in range(dp_size)]
|
|
||||||
group = dist.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
self.dp_group = group
|
|
||||||
|
|
||||||
|
|
||||||
class MoeContext(metaclass=SingletonMeta):
|
class MoeContext(metaclass=SingletonMeta):
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.testing import parameterize
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.context import MOE_CONTEXT
|
||||||
|
from colossalai.tensor import ColoParameter
|
||||||
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
from tests.test_zero.common import CONFIG
|
||||||
|
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||||
|
from tests.test_tensor.common_utils import debug_print
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("init_device_type", ['cpu', 'cuda'])
|
||||||
|
def exam_moe_colo_init(init_device_type):
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
if init_device_type == 'cuda':
|
||||||
|
init_device = get_current_device()
|
||||||
|
elif init_device_type == 'cpu':
|
||||||
|
init_device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unknown device found.")
|
||||||
|
|
||||||
|
with ColoInitContext(device=init_device):
|
||||||
|
model = MoeModel(checkpoint=True)
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)
|
||||||
|
|
||||||
|
if hasattr(param, "moe_info"):
|
||||||
|
param.set_process_group(param.moe_info.pg)
|
||||||
|
|
||||||
|
if hasattr(param, "moe_info"):
|
||||||
|
assert param.process_group.dp_world_size() == param.moe_info.dp_size
|
||||||
|
else:
|
||||||
|
assert param.process_group.dp_world_size() == world_size
|
||||||
|
|
||||||
|
|
||||||
|
def _run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
MOE_CONTEXT.setup(seed=42)
|
||||||
|
exam_moe_colo_init()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize("world_size", [4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_moe_colo_init(world_size):
|
||||||
|
run_func = partial(_run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_moe_colo_init(world_size=4)
|
Loading…
Reference in New Issue