mirror of https://github.com/hpcaitech/ColossalAI
[format] polish name format for MOE (#481)
parent
8d3250d74b
commit
65c0f380c2
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from .parallel_mode import ParallelMode
|
from .parallel_mode import ParallelMode
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
def _check_sanity():
|
def _check_sanity():
|
||||||
|
@ -10,7 +11,7 @@ def _check_sanity():
|
||||||
"pipeline parallel at present.")
|
"pipeline parallel at present.")
|
||||||
|
|
||||||
|
|
||||||
class MoeInfo:
|
class MoeParallelInfo:
|
||||||
"""Moe parallelism information, storing parallel sizes and groups.
|
"""Moe parallelism information, storing parallel sizes and groups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -78,11 +79,11 @@ class MoeContext:
|
||||||
self.use_kernel_optim = True
|
self.use_kernel_optim = True
|
||||||
|
|
||||||
self.has_setup = False
|
self.has_setup = False
|
||||||
self._info_dict = dict()
|
self._parallel_info_dict = dict()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def information(self):
|
def parallel_info_dict(self):
|
||||||
return self._info_dict
|
return self._parallel_info_dict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_initialized(self):
|
def is_initialized(self):
|
||||||
|
@ -110,17 +111,27 @@ class MoeContext:
|
||||||
moe_set_seed(seed)
|
moe_set_seed(seed)
|
||||||
self.has_setup = True
|
self.has_setup = True
|
||||||
|
|
||||||
def get_info(self, num_experts: int):
|
def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
|
||||||
"""Automatically deploys experts and returns parallel infomation about
|
"""Calculate the Data Parallel Group and Expert Parallel Group.
|
||||||
distributed communication groups.
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num_experts : int
|
||||||
|
The number experts
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int, MoeParallelInfo
|
||||||
|
number of local experts, the MoeParallelInfo of the current ep_size
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
|
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
|
||||||
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
|
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
|
||||||
|
|
||||||
assert gt_flag or lt_flag, "Automatic experts placement do not support such situation right now."
|
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number"\
|
||||||
|
" is not a multiple of ep size or vice versa."
|
||||||
|
|
||||||
# If the number of experts is greater than maximum expert parallel size,
|
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
|
||||||
# there are multiple experts in each GPU and each GPU has different experts
|
# there are multiple experts in each GPU and each GPU has different experts
|
||||||
# So it's data parallel size is 1
|
# So it's data parallel size is 1
|
||||||
# Otherwise, there is only one expert in each GPU
|
# Otherwise, there is only one expert in each GPU
|
||||||
|
@ -133,10 +144,10 @@ class MoeContext:
|
||||||
|
|
||||||
# Don't forget to multiply minimum data parallel size
|
# Don't forget to multiply minimum data parallel size
|
||||||
dp_size *= self.min_dp_size
|
dp_size *= self.min_dp_size
|
||||||
if not (ep_size in self.information):
|
if not (ep_size in self.parallel_info_dict):
|
||||||
self.information[ep_size] = MoeInfo(ep_size, dp_size)
|
self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
|
||||||
|
|
||||||
return num_local_experts, self.information[ep_size]
|
return num_local_experts, self.parallel_info_dict[ep_size]
|
||||||
|
|
||||||
def set_kernel_not_use(self):
|
def set_kernel_not_use(self):
|
||||||
self.use_kernel_optim = False
|
self.use_kernel_optim = False
|
||||||
|
|
|
@ -31,4 +31,5 @@ class MoeGradientHandler(BaseGradientHandler):
|
||||||
|
|
||||||
for ep_size in param_dict:
|
for ep_size in param_dict:
|
||||||
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
|
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
|
||||||
bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)
|
bucket_allreduce(param_list=param_dict[ep_size],
|
||||||
|
group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||||
from colossalai.context import ParallelMode, seed
|
from colossalai.context import ParallelMode, seed
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.core import MOE_CONTEXT
|
from colossalai.core import MOE_CONTEXT
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
class MoeExperts(nn.Module):
|
class MoeExperts(nn.Module):
|
||||||
|
@ -34,12 +35,12 @@ class Experts(MoeExperts):
|
||||||
:type num_experts: int
|
:type num_experts: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, expert, num_experts, **expert_args):
|
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
|
||||||
super().__init__("all_to_all", num_experts)
|
super().__init__("all_to_all", num_experts)
|
||||||
|
|
||||||
# Use seed to make every expert different from others
|
# Use seed to make every expert different from others
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])
|
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
|
||||||
|
|
||||||
# Attach parallel information for all parameters in Experts
|
# Attach parallel information for all parameters in Experts
|
||||||
for exp in self.experts:
|
for exp in self.experts:
|
||||||
|
|
|
@ -46,6 +46,6 @@ def sync_moe_model_param(model: nn.Module):
|
||||||
for ep_size in param_dict:
|
for ep_size in param_dict:
|
||||||
# When ep_size = world_size, communication is not needed
|
# When ep_size = world_size, communication is not needed
|
||||||
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
|
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
|
||||||
src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
|
src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
|
||||||
for param in param_dict[ep_size]:
|
for param in param_dict[ep_size]:
|
||||||
dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
|
dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
|
||||||
|
|
|
@ -36,7 +36,7 @@ def run_test(rank, world_size, port):
|
||||||
model = model.to(get_current_device())
|
model = model.to(get_current_device())
|
||||||
sync_moe_model_param(model)
|
sync_moe_model_param(model)
|
||||||
|
|
||||||
dist_dict = MOE_CONTEXT.information
|
dist_dict = MOE_CONTEXT.parallel_info_dict
|
||||||
assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
|
assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
|
||||||
assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
|
assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
|
||||||
# MoE model synchronization passed
|
# MoE model synchronization passed
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -16,7 +15,8 @@ D_FF = 8
|
||||||
CONFIG = dict()
|
CONFIG = dict()
|
||||||
|
|
||||||
|
|
||||||
def run_test(rank, world_size, port):
|
def run_test(rank, port):
|
||||||
|
world_size = 4
|
||||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
expert_module = nn.Linear
|
expert_module = nn.Linear
|
||||||
expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
|
expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
|
||||||
|
@ -33,36 +33,36 @@ def run_test(rank, world_size, port):
|
||||||
assert exp3.num_local_experts == 2
|
assert exp3.num_local_experts == 2
|
||||||
# experts deployment passed
|
# experts deployment passed
|
||||||
|
|
||||||
dist_dict = MOE_CONTEXT.information
|
parallel_info_dict = MOE_CONTEXT.parallel_info_dict
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
assert len(dist_dict) == 3
|
assert len(parallel_info_dict) == 3
|
||||||
assert dist.get_rank(dist_dict[4].ep_group) == rank
|
assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
|
||||||
assert dist.get_rank(dist_dict[2].ep_group) == rank % 2
|
assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
|
||||||
assert dist.get_rank(dist_dict[1].ep_group) == 0
|
assert dist.get_rank(parallel_info_dict[1].ep_group) == 0
|
||||||
|
|
||||||
assert dist.get_rank(dist_dict[4].dp_group) == 0
|
assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
|
||||||
assert dist.get_rank(dist_dict[2].dp_group) == rank // 2
|
assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
|
||||||
assert dist.get_rank(dist_dict[1].dp_group) == rank
|
assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
|
||||||
# group creation passed
|
# group creation passed
|
||||||
|
|
||||||
model = nn.ModuleList([exp0, exp1, exp2, exp3])
|
model = nn.ModuleList([exp0, exp1, exp2, exp3])
|
||||||
model = model.to(get_current_device())
|
model = model.to(get_current_device())
|
||||||
sync_moe_model_param(model)
|
sync_moe_model_param(model)
|
||||||
|
|
||||||
assert_equal_in_group(exp0.experts[0].weight.data, dist_dict[1].dp_group)
|
assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group)
|
||||||
assert_equal_in_group(exp0.experts[0].bias.data, dist_dict[1].dp_group)
|
assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group)
|
||||||
# MOE experts layout success when ep_size = 1
|
# MOE experts layout success when ep_size = 1
|
||||||
|
|
||||||
assert_equal_in_group(exp1.experts[0].weight.data, dist_dict[2].dp_group)
|
assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group)
|
||||||
assert_equal_in_group(exp1.experts[0].bias.data, dist_dict[2].dp_group)
|
assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group)
|
||||||
# MOE experts layout success when ep_size = 2
|
# MOE experts layout success when ep_size = 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
def test_moe_initialization():
|
def test_moe_initialization():
|
||||||
world_size = 4
|
world_size = 4
|
||||||
run_func = partial(run_test, world_size=world_size, port=free_port())
|
run_func = partial(run_test, port=free_port())
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue