[MOE] polish moe_env (#467)

pull/471/head
HELSON 3 years ago committed by GitHub
parent bccbc15861
commit aff9d354f7

@@ -4,4 +4,4 @@
 from colossalai.context import ParallelContext, MoeContext
 global_context = ParallelContext.get_instance()
-moe_context = MoeContext.get_instance()
+MOE_CONTEXT = MoeContext.get_instance()

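The rest of the diff follows from this rename: downstream code now reads MoE parallel state from the renamed singleton. A minimal orientation sketch (assuming colossalai is installed; the listed attributes are the ones exercised by the hunks below):

from colossalai.core import MOE_CONTEXT   # formerly imported as `moe_context` / aliased to `moe_env`

# Attributes and methods used throughout this PR:
#   MOE_CONTEXT.world_size, MOE_CONTEXT.max_ep_size, MOE_CONTEXT.use_kernel_optim
#   MOE_CONTEXT.get_info(num_experts) -> (num_local_experts, dist_info)
#   MOE_CONTEXT.information[ep_size].ep_group / .dp_group
#   MOE_CONTEXT.add_loss(l_aux), MOE_CONTEXT.get_loss()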
@ -1,4 +1,4 @@
from colossalai.core import global_context as gpc, moe_context as moe_env from colossalai.core import global_context as gpc, MOE_CONTEXT
from colossalai.registry import GRADIENT_HANDLER from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict from colossalai.utils.moe import get_moe_epsize_param_dict
from ._base_gradient_handler import BaseGradientHandler from ._base_gradient_handler import BaseGradientHandler
@ -30,5 +30,5 @@ class MoeGradientHandler(BaseGradientHandler):
bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA)) bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
for ep_size in param_dict: for ep_size in param_dict:
if ep_size != 1 and ep_size != moe_env.world_size: if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group) bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)

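For orientation, the handler's logic amounts to the following simplified standalone sketch. It assumes torch.distributed is already initialized and that MoE parameters carry the moe_info attribute attached in experts.py; bucket_allreduce is replaced here by a plain per-parameter all-reduce, and allreduce_moe_grads is a hypothetical helper name.

import torch.distributed as dist

def allreduce_moe_grads(model, data_parallel_group):
    # Group parameters by expert-parallel size; non-MoE params behave as ep_size == 1.
    groups = {}
    for param in model.parameters():
        ep_size = param.moe_info.ep_size if hasattr(param, 'moe_info') else 1
        groups.setdefault(ep_size, []).append(param)

    for ep_size, params in groups.items():
        if ep_size == 1:
            group = data_parallel_group              # ordinary data parallelism
        elif ep_size == dist.get_world_size():
            continue                                 # each rank holds distinct experts, nothing to reduce
        else:
            group = params[0].moe_info.dp_group      # replicas of the same experts
        for param in params:
            if param.grad is not None:
                dist.all_reduce(param.grad, group=group)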
@@ -4,11 +4,11 @@ from torch import Tensor
 from typing import Any, Tuple, Optional
 from torch.distributed import ProcessGroup

-U_CUDA_MODE = False
+COL_MOE_KERNEL_FLAG = False
 try:
     import colossal_moe_cuda

-    U_CUDA_MODE = True
+    COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")

@@ -17,7 +17,6 @@ class AllGather(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group

@@ -40,7 +39,6 @@ class ReduceScatter(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group

@@ -149,7 +147,7 @@ class MoeCombine(torch.autograd.Function):

 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and U_CUDA_MODE:
+    if flag and COL_MOE_KERNEL_FLAG:
         return colossal_moe_cuda.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1

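The renamed COL_MOE_KERNEL_FLAG only selects which backend computes the exclusive cumulative sum; both branches of moe_cumsum agree. A quick kernel-free illustration of what the fallback path returns, with a made-up one-hot assignment:

import torch

mask = torch.tensor([[1, 0],
                     [1, 0],
                     [0, 1]])          # [s, e] one-hot expert assignment
ranks = torch.cumsum(mask, dim=0) - 1  # fallback path of moe_cumsum
print(ranks)
# tensor([[ 0, -1],
#         [ 1, -1],
#         [ 1,  0]])
# Where mask is 1, row i / column e gives token i's slot inside expert e's capacity buffer.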
@@ -2,18 +2,24 @@ import math

 import torch
 import torch.nn as nn

-from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
+from colossalai.core import MOE_CONTEXT


 class MoeExperts(nn.Module):
+    """Basic class for experts in MoE. It stores what kind of communication experts use
+    to exchange tokens, how many experts reside on a single GPU, and parallel information such as
+    expert parallel size, data parallel size and their distributed communication groups.
+    """

-    def __init__(self, comm: str):
+    def __init__(self, comm_name: str, num_experts: int):
         super().__init__()
-        assert comm in {"all_to_all", "all_gather"}, \
+        assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\n Please use Experts build function."
-        self.comm = comm
+        self.comm_name = comm_name
+        # Get the configuration of experts' deployment and parallel information from the MoE context
+        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


 class Experts(MoeExperts):
@@ -29,53 +35,48 @@ class Experts(MoeExperts):
     """

     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__("all_to_all")
-
-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        with seed(ParallelMode.MOE_MODEL):
-            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
+        super().__init__("all_to_all", num_experts)

+        # Use seed to make every expert different from others
+        with seed(ParallelMode.TENSOR):
+            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])

+        # Attach parallel information for all parameters in Experts
         for exp in self.experts:
             for param in exp.parameters():
-                param.__setattr__('moe_param', True)
+                param.__setattr__('moe_info', self.dist_info)

-        self.num_local_experts = num_local_experts
-
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
+        # Split inputs for each expert
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []

+        # Get outputs from each expert
         for i in range(self.num_local_experts):
             expert_output.append(self.experts[i](expert_input[i]))

+        # Concatenate all outputs together
         output = torch.cat(expert_output, dim=1).contiguous()
         return output
 class FFNExperts(MoeExperts):
+    """Use torch.bmm to speed up for multiple experts.
+    """

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_to_all")
-
-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
-        self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
-        self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
-        self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
+        super().__init__("all_to_all", num_experts)
+
+        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
+        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)

-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

@@ -85,7 +86,7 @@ class FFNExperts(MoeExperts):
         self.drop = nn.Dropout(p=drop_rate)

         for param in self.parameters():
-            param.__setattr__('moe_param', True)
+            param.__setattr__('moe_info', self.dist_info)

     def forward(self, inputs):    # inputs [g, el, c, h]

@@ -99,9 +100,9 @@ class FFNExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)

-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         with seed(ParallelMode.TENSOR):
             outputs = self.drop(out_model)    # outputs [el, gc, h]
@ -111,14 +112,18 @@ class FFNExperts(MoeExperts):
class TPExperts(MoeExperts): class TPExperts(MoeExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
case that the number of experts can't be divied by maximum expert parallel size or
maximum expert parallel size can't be divied by the number of experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather") super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
assert d_ff % moe_env.model_parallel_size == 0, \ assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
"d_ff should be divied by moe model size" "d_ff should be divied by maximum expert parallel size"
p_ff = d_ff // moe_env.model_parallel_size p_ff = d_ff // MOE_CONTEXT.max_ep_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
@ -129,7 +134,7 @@ class TPExperts(MoeExperts):
s1 = math.sqrt(0.1 / d_model) s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff) s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.MOE_MODEL): with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1) nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1) nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2) nn.init.trunc_normal_(self.w2, std=s2)
@ -139,9 +144,9 @@ class TPExperts(MoeExperts):
self.act = nn.GELU() if activation is None else activation self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate) self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__('moe_param', True) self.w1.__setattr__('moe_info', self.dist_info)
self.w2.__setattr__('moe_param', True) self.w2.__setattr__('moe_info', self.dist_info)
self.b1.__setattr__('moe_param', True) self.b1.__setattr__('moe_info', self.dist_info)
def forward(self, inputs): # inputs [g, e, c, h] def forward(self, inputs): # inputs [g, e, c, h]
@ -155,9 +160,9 @@ class TPExperts(MoeExperts):
out_ff = torch.baddbmm(self.b1, inputs, self.w1) out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff) out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
inter = self.drop(out_act) out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2) out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h] outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape) outputs = outputs.reshape(inshape)

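A framework-free sketch (hypothetical sizes) of the batched computation that FFNExperts and TPExperts express with torch.baddbmm: each local expert applies its own weights to its own slice of the dispatched tokens in a single batched matmul.

import torch
import torch.nn.functional as F

e, c, d_model, d_ff = 4, 8, 16, 64                 # hypothetical: local experts, capacity, dims
x  = torch.randn(e, c, d_model)                    # tokens already grouped per local expert
w1, b1 = torch.randn(e, d_model, d_ff), torch.zeros(e, 1, d_ff)
w2, b2 = torch.randn(e, d_ff, d_model), torch.zeros(e, 1, d_model)

h = F.gelu(torch.baddbmm(b1, x, w1))               # [e, c, d_ff], one matmul per expert
y = torch.baddbmm(b2, h, w2)                       # [e, c, d_model]
assert y.shape == (e, c, d_model)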
@@ -4,14 +4,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
-from colossalai.context import ParallelMode
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
-from typing import Callable
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup


 class Top1Router(nn.Module):
@@ -19,8 +18,8 @@ class Top1Router(nn.Module):
     for routing usage. More detailed explanation can be found in the paper about Switch Transformer
     of Google.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits

@@ -66,7 +65,7 @@ class Top1Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):

         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -82,10 +81,10 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

@@ -103,7 +102,7 @@ class Top1Router(nn.Module):
         ranks = torch.sum(mask * ranks, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
@@ -120,8 +119,8 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More detailed explanation can be found in the paper about ViT-MoE.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
     :param drop_tks: Whether to drop tokens in evaluation

@@ -157,7 +156,7 @@ class Top2Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         # inputs: [s, h]
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -177,10 +176,10 @@ class Top2Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

@@ -195,7 +194,7 @@ class Top2Router(nn.Module):
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask1 = torch.sum(mask1, dim=-1)
             mask2 = torch.sum(mask2, dim=-1)
@@ -241,34 +240,36 @@ class MoeLayer(nn.Module):
         self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
-        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.ep_group = experts.dist_info.ep_group
+        self.ep_size = experts.dist_info.ep_size
+        self.num_local_experts = experts.num_local_experts

     def a2a_process(self, dispatch_data: torch.Tensor):
-        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)

         input_shape = expert_input.shape

-        expert_input = expert_input.reshape(moe_env.model_parallel_size,
-                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)

         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)

-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output

     def tp_process(self, dispatch_data: torch.Tensor):
-        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
         expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
-        router_res = self.router(gate_output, self.cuda_mode)
+        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

-        if self.cuda_mode:
+        if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:

@@ -276,16 +277,16 @@ class MoeLayer(nn.Module):
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

         # dispatch_data [e, c, h]
-        if self.experts.comm == "all_to_all":
+        if self.experts.comm_name == "all_to_all":
             expert_output = self.a2a_process(dispatch_data)
-        elif self.experts.comm == "all_gather":
+        elif self.experts.comm_name == "all_gather":
             expert_output = self.tp_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")

         # expert_output [e, c, h]
-        if self.cuda_mode:
+        if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:

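When the fused CUDA kernel is unavailable (use_kernel is False), dispatch reduces to a matmul between the [s, e, c] mask and the flattened tokens, as in the else branch above. A small self-contained sketch with made-up sizes:

import torch

s, e, c, h = 6, 2, 3, 8                       # hypothetical: tokens, experts, capacity, hidden
tokens = torch.randn(s, h)
sec_mask_f = torch.zeros(s, e, c)
sec_mask_f[0, 0, 0] = 1.0                     # token 0 -> expert 0, slot 0
sec_mask_f[1, 1, 0] = 1.0                     # token 1 -> expert 1, slot 0
sec_mask_f[2, 0, 1] = 1.0                     # token 2 -> expert 0, slot 1

dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)   # [e, c, h]
assert torch.allclose(dispatch_data[0, 1], tokens[2])                # expert 0, slot 1 holds token 2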
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts

@@ -36,7 +36,7 @@ class UniformNoiseGenerator:
     :type eps: float
     """

-    def __init__(self, eps: float):
+    def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
                                                            high=torch.tensor(1.0 + eps,
                                                                              device=get_current_device())).rsample

@@ -55,10 +55,10 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):

 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    moe_mp_size = moe_env.model_parallel_size
-    if num_experts % moe_mp_size == 0:
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % moe_mp_size == 0:
+    elif d_ff % mep_size == 0:
         return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
+        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")

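The widened condition means FFNExperts now also covers the case where the maximum expert parallel size is a multiple of the number of experts (MOE_CONTEXT.get_info decides how experts map onto ranks); only when neither divides the other does the choice fall back to TPExperts. A plain-Python mirror of that branch with hypothetical numbers (pick_expert_impl is an illustrative helper, not part of the library):

def pick_expert_impl(num_experts: int, mep_size: int, d_ff: int) -> str:
    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
        return "FFNExperts"        # experts map cleanly onto the expert-parallel group
    elif d_ff % mep_size == 0:
        return "TPExperts"         # split the FFN hidden dim across the group instead
    raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")

assert pick_expert_impl(num_experts=4, mep_size=8, d_ff=3072) == "FFNExperts"
assert pick_expert_impl(num_experts=6, mep_size=4, d_ff=3072) == "TPExperts"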
@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT


 @LOSSES.register_module

@@ -14,6 +14,7 @@ class MoeCrossEntropyLoss(_Loss):
     :type aux_weight: float, optional
     """

     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)

@@ -21,7 +22,7 @@ class MoeCrossEntropyLoss(_Loss):
     def forward(self, *args):
         main_loss = self.loss(*args)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss

@@ -37,6 +38,7 @@ class MoeLoss(_Loss):
     :type aux_weight: float
     :type loss_fn: Callable
     """

     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)

@@ -44,5 +46,5 @@ class MoeLoss(_Loss):
     def forward(self, *args, **kwargs):
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss

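MOE_CONTEXT.get_loss() returns the auxiliary load-balancing term that the routers accumulate via MOE_CONTEXT.add_loss(): the product of mean gate probability and mean routed-token fraction per expert, summed and scaled by the number of experts. A standalone sketch of that quantity for the top-1 case, with made-up sizes:

import torch
import torch.nn.functional as F

num_experts, s = 4, 32                       # hypothetical: experts, tokens
logits = torch.softmax(torch.randn(s, num_experts), dim=-1)
top1_idx = torch.argmax(logits, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts)

me = torch.mean(logits, dim=0)               # mean gate probability per expert
ce = torch.mean(mask.float(), dim=0)         # fraction of tokens routed to each expert
l_aux = num_experts * torch.sum(me * ce)     # what the router passes to MOE_CONTEXT.add_loss
# MoeLoss / MoeCrossEntropyLoss then add aux_weight * l_aux to the main loss.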
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

@@ -45,7 +45,7 @@ def sync_moe_model_param(model: nn.Module):
        for ep_size in param_dict:
            # When ep_size = world_size, communication is not needed
-           if ep_size != 1 and ep_size != moe_env.world_size:
-               src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
+           if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+               src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
                for param in param_dict[ep_size]:
                    dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
