
[format] polish name format for MOE (#481)

Branch: pull/483/head
Author: Jiarui Fang, committed via GitHub 3 years ago
Commit: 65c0f380c2
Changed files:

  1. colossalai/context/moe_context.py (35 changes)
  2. colossalai/engine/gradient_handler/_moe_gradient_handler.py (3 changes)
  3. colossalai/nn/layer/moe/experts.py (5 changes)
  4. colossalai/utils/moe.py (2 changes)
  5. tests/test_moe/test_grad_handler.py (2 changes)
  6. tests/test_moe/test_moe_group.py (30 changes)

colossalai/context/moe_context.py (35 changes)

@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 from .parallel_mode import ParallelMode
+from typing import Tuple

 def _check_sanity():
@@ -10,7 +11,7 @@ def _check_sanity():
                     "pipeline parallel at present.")

-class MoeInfo:
+class MoeParallelInfo:
     """Moe parallelism information, storing parallel sizes and groups.
     """
@@ -78,11 +79,11 @@ class MoeContext:
         self.use_kernel_optim = True
         self.has_setup = False
-        self._info_dict = dict()
+        self._parallel_info_dict = dict()

     @property
-    def information(self):
-        return self._info_dict
+    def parallel_info_dict(self):
+        return self._parallel_info_dict

     @property
     def is_initialized(self):
@@ -110,17 +111,27 @@ class MoeContext:
         moe_set_seed(seed)
         self.has_setup = True

-    def get_info(self, num_experts: int):
-        """Automatically deploys experts and returns parallel information about
-        distributed communication groups.
-        """
+    def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
+        """Calculate the data parallel group and expert parallel group.
+
+        Parameters
+        ----------
+        num_experts : int
+            The number of experts
+
+        Returns
+        -------
+        int, MoeParallelInfo
+            The number of local experts and the MoeParallelInfo of the current ep_size
+        """
         gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
         lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
-        assert gt_flag or lt_flag, "Automatic experts placement do not support such situation right now."
+        assert gt_flag or lt_flag, "Automatic expert placement does not support an expert number"\
+            " that is not a multiple of ep_size or vice versa."

-        # If the number of experts is greater than maximum expert parallel size,
+        # If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
         # there are multiple experts in each GPU and each GPU has different experts
         # So it's data parallel size is 1
         # Otherwise, there is only one expert in each GPU
@@ -133,10 +144,10 @@ class MoeContext:
             # Don't forget to multiply minimum data parallel size
             dp_size *= self.min_dp_size

-        if not (ep_size in self.information):
-            self.information[ep_size] = MoeInfo(ep_size, dp_size)
+        if not (ep_size in self.parallel_info_dict):
+            self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)

-        return num_local_experts, self.information[ep_size]
+        return num_local_experts, self.parallel_info_dict[ep_size]

     def set_kernel_not_use(self):
         self.use_kernel_optim = False
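
The placement rule encoded by get_info can be checked with a small standalone sketch. The helper below is hypothetical and written only for this note: it mirrors the gt_flag / lt_flag branches shown above (and ignores the min_dp_size multiplier), using plain Python with no colossalai dependency.

from typing import Tuple

def plan_expert_placement(num_experts: int, max_ep_size: int) -> Tuple[int, int, int]:
    # Hypothetical mirror of MoeContext.get_info's placement arithmetic.
    # Returns (num_local_experts, ep_size, dp_size) for one MoE layer.
    gt_flag = num_experts % max_ep_size == 0    # more experts than the maximum ep size
    lt_flag = max_ep_size % num_experts == 0    # fewer experts than the maximum ep size
    assert gt_flag or lt_flag, "num_experts must be a multiple of ep size or vice versa"
    if gt_flag:
        # several distinct experts per GPU, so no data parallelism among experts
        return num_experts // max_ep_size, max_ep_size, 1
    # one expert per GPU; the remaining GPUs hold replicas and form the dp group
    return 1, num_experts, max_ep_size // num_experts

# 8 experts on 4 GPUs -> (2, 4, 1); 2 experts on 4 GPUs -> (1, 2, 2)
print(plan_expert_placement(8, 4), plan_expert_placement(2, 4))

These values line up with what the updated tests in tests/test_moe/test_moe_group.py below expect for a 4-GPU run.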

colossalai/engine/gradient_handler/_moe_gradient_handler.py (3 changes)

@@ -31,4 +31,5 @@ class MoeGradientHandler(BaseGradientHandler):
             for ep_size in param_dict:
                 if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                    bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)
+                    bucket_allreduce(param_list=param_dict[ep_size],
+                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)

colossalai/nn/layer/moe/experts.py (5 changes)

@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.core import MOE_CONTEXT
+from typing import Type

 class MoeExperts(nn.Module):
@@ -34,12 +35,12 @@ class Experts(MoeExperts):
     :type num_experts: int
     """

-    def __init__(self, expert, num_experts, **expert_args):
+    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

         # Use seed to make every expert different from others
         with seed(ParallelMode.TENSOR):
-            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])
+            self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])

         # Attach parallel information for all parameters in Experts
         for exp in self.experts:
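
The renamed expert_cls argument is a class used as a factory, not a module instance: Experts calls it once per local expert so every expert gets its own parameters. Below is a minimal single-process sketch of that pattern, assuming only PyTorch; the LocalExperts class and its names are illustrative, not part of colossalai, and the seed / MOE_CONTEXT handling above is omitted.

from typing import Type
import torch.nn as nn

class LocalExperts(nn.Module):
    # Toy stand-in for colossalai's Experts: build several copies of expert_cls.
    def __init__(self, expert_cls: Type[nn.Module], num_local_experts: int, **expert_args):
        super().__init__()
        # each expert_cls(**expert_args) call creates an independently initialized expert
        self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(num_local_experts)])

# usage in the spirit of tests/test_moe/test_moe_group.py: nn.Linear experts mapping D_MODEL to D_FF
bank = LocalExperts(nn.Linear, num_local_experts=2, in_features=4, out_features=8)
print(len(bank.experts))    # 2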

colossalai/utils/moe.py (2 changes)

@@ -46,6 +46,6 @@ def sync_moe_model_param(model: nn.Module):
         for ep_size in param_dict:
             # When ep_size = world_size, communication is not needed
             if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
+                src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
                 for param in param_dict[ep_size]:
                     dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)

tests/test_moe/test_grad_handler.py (2 changes)

@@ -36,7 +36,7 @@ def run_test(rank, world_size, port):
     model = model.to(get_current_device())
     sync_moe_model_param(model)

-    dist_dict = MOE_CONTEXT.information
+    dist_dict = MOE_CONTEXT.parallel_info_dict
     assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
     # MoE model synchronization passed

tests/test_moe/test_moe_group.py (30 changes)

@@ -1,6 +1,5 @@
 from functools import partial
 import pytest
-import torch
 import torch.nn as nn
 import torch.multiprocessing as mp
 import torch.distributed as dist
@@ -16,7 +15,8 @@ D_FF = 8
 CONFIG = dict()

-def run_test(rank, world_size, port):
+def run_test(rank, port):
+    world_size = 4
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     expert_module = nn.Linear
     expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
@@ -33,36 +33,36 @@ def run_test(rank, world_size, port):
     assert exp3.num_local_experts == 2
     # experts deployment passed

-    dist_dict = MOE_CONTEXT.information
+    parallel_info_dict = MOE_CONTEXT.parallel_info_dict
     rank = dist.get_rank()

-    assert len(dist_dict) == 3
-    assert dist.get_rank(dist_dict[4].ep_group) == rank
-    assert dist.get_rank(dist_dict[2].ep_group) == rank % 2
-    assert dist.get_rank(dist_dict[1].ep_group) == 0
+    assert len(parallel_info_dict) == 3
+    assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
+    assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
+    assert dist.get_rank(parallel_info_dict[1].ep_group) == 0

-    assert dist.get_rank(dist_dict[4].dp_group) == 0
-    assert dist.get_rank(dist_dict[2].dp_group) == rank // 2
-    assert dist.get_rank(dist_dict[1].dp_group) == rank
+    assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
+    assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
+    assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
     # group creation passed

     model = nn.ModuleList([exp0, exp1, exp2, exp3])
     model = model.to(get_current_device())
     sync_moe_model_param(model)

-    assert_equal_in_group(exp0.experts[0].weight.data, dist_dict[1].dp_group)
-    assert_equal_in_group(exp0.experts[0].bias.data, dist_dict[1].dp_group)
+    assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group)
+    assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group)
     # MOE experts layout success when ep_size = 1

-    assert_equal_in_group(exp1.experts[0].weight.data, dist_dict[2].dp_group)
-    assert_equal_in_group(exp1.experts[0].bias.data, dist_dict[2].dp_group)
+    assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group)
+    assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group)
     # MOE experts layout success when ep_size = 2

 @pytest.mark.dist
 def test_moe_initialization():
     world_size = 4
-    run_func = partial(run_test, world_size=world_size, port=free_port())
+    run_func = partial(run_test, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
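
The rank assertions in the hunk above follow from splitting the 4 ranks into consecutive expert-parallel blocks and strided data-parallel groups: for ep_size = 2 the ep groups are {0, 1} and {2, 3}, so a rank's position inside its ep group is rank % 2 and inside its dp group is rank // 2. The snippet below is a hypothetical, pure-Python check of that arithmetic for world_size = 4; the actual group construction lives inside colossalai and may differ in detail.

WORLD_SIZE = 4

def expected_group_ranks(ep_size: int, world_size: int = WORLD_SIZE):
    # Rank of every GPU inside its ep group (consecutive block) and dp group (strided).
    ep_rank = [rank % ep_size for rank in range(world_size)]     # position within the ep block
    dp_rank = [rank // ep_size for rank in range(world_size)]    # which ep block the rank is in
    return ep_rank, dp_rank

for ep_size in (1, 2, 4):
    print(ep_size, expected_group_ranks(ep_size))
# ep_size=4: ep_rank == rank, dp_rank == 0; ep_size=1: ep_rank == 0, dp_rank == rank,
# matching the assertions on parallel_info_dict[4], [2] and [1] in the test above.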
