mirror of https://github.com/hpcaitech/ColossalAI
[MOE] polish moe_env (#467)
parent: bccbc15861
commit: aff9d354f7
@@ -4,4 +4,4 @@
 from colossalai.context import ParallelContext, MoeContext
 
 global_context = ParallelContext.get_instance()
-moe_context = MoeContext.get_instance()
+MOE_CONTEXT = MoeContext.get_instance()

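This hunk renames the module-level MoE singleton from moe_context to MOE_CONTEXT. Everything the later hunks need (world size, maximum expert parallel size, per-ep-size process groups, the auxiliary-loss accumulator) hangs off this one instance. A minimal orientation sketch, not part of the commit, and only meaningful after colossalai has been launched so the context is populated:

    from colossalai.core import MOE_CONTEXT

    def describe_moe_context(num_experts: int = 8):
        # Attributes referenced by the hunks below; values are only valid after launch/setup.
        print("world size:", MOE_CONTEXT.world_size)
        print("max expert parallel size:", MOE_CONTEXT.max_ep_size)
        # get_info returns how many experts live on this rank plus the parallel groups to use.
        num_local_experts, dist_info = MOE_CONTEXT.get_info(num_experts)
        print("local experts:", num_local_experts, "ep_size:", dist_info.ep_size)
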
@@ -1,4 +1,4 @@
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler

@@ -30,5 +30,5 @@ class MoeGradientHandler(BaseGradientHandler):
            bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
 
            for ep_size in param_dict:
-               if ep_size != 1 and ep_size != moe_env.world_size:
-                   bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
+               if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                   bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)

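The handler reduces gradients bucket by bucket: parameters with expert-parallel size 1 are ordinary data-parallel parameters and use the global DATA group; a bucket whose ep_size equals the world size has a data-parallel degree of 1 and is skipped; every other bucket uses the data-parallel group stored for its ep_size. A rough equivalent written directly against torch.distributed (a sketch only; the epsize_param_dict grouping is taken as given and bucket_allreduce's bucketing is replaced by one all-reduce per gradient):

    import torch.distributed as dist

    def allreduce_moe_grads(epsize_param_dict, data_group, moe_ctx):
        """Mirror of MoeGradientHandler's grouping logic, one all-reduce per gradient."""
        for ep_size, params in epsize_param_dict.items():
            if ep_size == 1:
                group = data_group                              # plain data parallelism
            elif ep_size == moe_ctx.world_size:
                continue                                        # dp degree is 1, nothing to reduce
            else:
                group = moe_ctx.information[ep_size].dp_group   # dp group attached to this ep_size
            for p in params:
                if p.grad is not None:
                    dist.all_reduce(p.grad, group=group)
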
@@ -4,11 +4,11 @@ from torch import Tensor
 from typing import Any, Tuple, Optional
 from torch.distributed import ProcessGroup
 
-U_CUDA_MODE = False
+COL_MOE_KERNEL_FLAG = False
 try:
     import colossal_moe_cuda
 
-    U_CUDA_MODE = True
+    COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")

@@ -17,7 +17,6 @@ class AllGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group

@@ -40,7 +39,6 @@ class ReduceScatter(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group

@@ -149,7 +147,7 @@ class MoeCombine(torch.autograd.Function):
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and U_CUDA_MODE:
+    if flag and COL_MOE_KERNEL_FLAG:
         return colossal_moe_cuda.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1

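moe_cumsum keeps the fused kernel (cumsum_sub_one) behind the renamed flag and otherwise falls back to plain PyTorch: an inclusive cumsum minus one. Applied to a one-hot routing mask, that gives every token its zero-based rank inside the expert it was sent to, which is what the routers below use to compute destination slots. A small self-contained check of the fallback path (kernel path omitted):

    import torch

    def moe_cumsum_fallback(mask: torch.Tensor) -> torch.Tensor:
        # mask: [num_tokens, num_experts] one-hot routing mask
        return torch.cumsum(mask, dim=0) - 1

    mask = torch.tensor([[1, 0],
                         [1, 0],
                         [0, 1],
                         [1, 0]])
    ranks = moe_cumsum_fallback(mask)
    # Rank of each token within the expert column it was routed to:
    print((mask * ranks).sum(dim=-1))   # tensor([0, 1, 0, 2])
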
@@ -2,18 +2,24 @@ import math
 import torch
 import torch.nn as nn
-from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
+from colossalai.core import MOE_CONTEXT
 
 
 class MoeExperts(nn.Module):
+    """Basic class for experts in MoE. It stores what kind of communication experts use
+    to exchange tokens, how many experts there are on a single GPU, and parallel information such as
+    expert parallel size, data parallel size and their distributed communication groups.
+    """
 
-    def __init__(self, comm: str):
+    def __init__(self, comm_name: str, num_experts: int):
         super().__init__()
-        assert comm in {"all_to_all", "all_gather"}, \
+        assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\n Please use Experts build function."
-        self.comm = comm
+        self.comm_name = comm_name
+        # Get the configuration of experts' deployment and parallel information from the moe context
+        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
 
 
 class Experts(MoeExperts):

@@ -29,53 +35,48 @@ class Experts(MoeExperts):
     """
 
     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__("all_to_all")
-
-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        with seed(ParallelMode.MOE_MODEL):
-            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
+        super().__init__("all_to_all", num_experts)
 
+        # Use seed to make every expert different from others
+        with seed(ParallelMode.TENSOR):
+            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])
+
+        # Attach parallel information for all parameters in Experts
         for exp in self.experts:
             for param in exp.parameters():
-                param.__setattr__('moe_param', True)
-
-        self.num_local_experts = num_local_experts
+                param.__setattr__('moe_info', self.dist_info)
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
+        # Split inputs for each expert
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []
 
+        # Get outputs from each expert
         for i in range(self.num_local_experts):
            expert_output.append(self.experts[i](expert_input[i]))
 
+        # Concatenate all outputs together
         output = torch.cat(expert_output, dim=1).contiguous()
         return output
 
 
 class FFNExperts(MoeExperts):
     """Use torch.bmm to speed up for multiple experts.
     """
 
     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_to_all")
-
-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
-        self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
+        super().__init__("all_to_all", num_experts)
 
-        self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
-        self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
+        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
+
+        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
 
         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)
 
-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

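Experts now takes its num_local_experts and dist_info from MOE_CONTEXT.get_info in the base class instead of computing them from the old moe_env, but the forward pass is unchanged: chunk the dispatched tensor along the expert dimension, run each local expert, and concatenate. A shape sketch with a stand-in nn.Linear playing the role of the expert module (all names and sizes below are illustrative, not from the commit):

    import torch
    import torch.nn as nn

    num_local_experts, capacity, d_model = 2, 4, 8
    # [ep_size groups, local experts, tokens per expert, hidden]
    x = torch.randn(1, num_local_experts, capacity, d_model)

    experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_local_experts)])

    chunks = torch.chunk(x, num_local_experts, dim=1)            # one slice per local expert
    outs = [experts[i](chunks[i]) for i in range(num_local_experts)]
    y = torch.cat(outs, dim=1).contiguous()
    print(y.shape)                                               # torch.Size([1, 2, 4, 8])
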
@@ -85,7 +86,7 @@ class FFNExperts(MoeExperts):
         self.drop = nn.Dropout(p=drop_rate)
 
         for param in self.parameters():
-            param.__setattr__('moe_param', True)
+            param.__setattr__('moe_info', self.dist_info)
 
     def forward(self, inputs):    # inputs [g, el, c, h]

@@ -99,9 +100,9 @@ class FFNExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)
 
-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         with seed(ParallelMode.TENSOR):
             outputs = self.drop(out_model)    # outputs [el, gc, h]

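FFNExperts stacks the weights of all local experts into single parameters with a leading num_local_experts dimension, so the whole feed-forward runs as batched matrix products via torch.baddbmm; the inter to out_inter rename above is purely cosmetic. A standalone sketch of that computation with made-up sizes (dropout and the seeded RNG context are left out):

    import torch

    e, tokens, d_model, d_ff = 2, 6, 8, 16             # e = local experts
    x  = torch.randn(e, tokens, d_model)               # dispatched tokens, one batch per expert
    w1 = torch.randn(e, d_model, d_ff)
    b1 = torch.randn(e, 1, d_ff)
    w2 = torch.randn(e, d_ff, d_model)
    b2 = torch.randn(e, 1, d_model)

    out_ff    = torch.baddbmm(b1, x, w1)               # [e, tokens, d_ff]
    out_act   = torch.nn.functional.gelu(out_ff)
    out_inter = out_act                                # dropout omitted in this sketch
    out_model = torch.baddbmm(b2, out_inter, w2)       # [e, tokens, d_model]
    print(out_model.shape)                             # torch.Size([2, 6, 8])
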
@@ -111,14 +112,18 @@ class FFNExperts(MoeExperts):
 
 
 class TPExperts(MoeExperts):
+    """Use tensor parallelism to split each expert evenly, which can deploy experts in
+    case that the number of experts can't be divided by maximum expert parallel size or
+    maximum expert parallel size can't be divided by the number of experts.
+    """
 
     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_gather")
+        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
 
-        assert d_ff % moe_env.model_parallel_size == 0, \
-            "d_ff should be divied by moe model size"
+        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
+            "d_ff should be divided by maximum expert parallel size"
 
-        p_ff = d_ff // moe_env.model_parallel_size
+        p_ff = d_ff // MOE_CONTEXT.max_ep_size
 
         self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
         self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

@@ -129,7 +134,7 @@ class TPExperts(MoeExperts):
         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)
 
-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

@@ -139,9 +144,9 @@ class TPExperts(MoeExperts):
         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)
 
-        self.w1.__setattr__('moe_param', True)
-        self.w2.__setattr__('moe_param', True)
-        self.b1.__setattr__('moe_param', True)
+        self.w1.__setattr__('moe_info', self.dist_info)
+        self.w2.__setattr__('moe_info', self.dist_info)
+        self.b1.__setattr__('moe_info', self.dist_info)
 
     def forward(self, inputs):    # inputs [g, e, c, h]

@@ -155,9 +160,9 @@ class TPExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)
 
-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         outputs = self.drop(out_model)    # outputs [e, gc, h]
 
         outputs = outputs.reshape(inshape)

@@ -4,14 +4,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
-from colossalai.context import ParallelMode
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
-from typing import Callable
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup
 
 
 class Top1Router(nn.Module):

@@ -19,8 +18,8 @@ class Top1Router(nn.Module):
     for routing usage. More detailed information can be found in the paper about Switch Transformer
     of Google.
 
-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits

@@ -66,7 +65,7 @@ class Top1Router(nn.Module):
         assert capacity > 0
         return capacity
 
-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
 
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -82,10 +81,10 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

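The balancing term itself is untouched: with logits taken as the post-softmax routing probabilities, me is the mean probability mass per expert, ce the fraction of tokens actually routed to each expert, and their dot product scaled by num_experts is the Switch-style auxiliary loss, now accumulated on MOE_CONTEXT rather than the old moe_env. A standalone numerical sketch:

    import torch

    num_experts = 4
    logits = torch.softmax(torch.randn(32, num_experts), dim=-1)   # routing probabilities, [tokens, experts]
    top1_idx = torch.argmax(logits, dim=-1)
    mask = torch.nn.functional.one_hot(top1_idx, num_classes=num_experts)

    me = torch.mean(logits, dim=0)          # mean probability mass per expert
    ce = torch.mean(mask.float(), dim=0)    # fraction of tokens routed to each expert
    l_aux = num_experts * torch.sum(me * ce)
    print(l_aux)    # close to 1.0 when routing is perfectly balanced, larger when it collapses
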
@@ -103,7 +102,7 @@ class Top1Router(nn.Module):
         ranks = torch.sum(mask * ranks, dim=-1)
 
-        if cuda_mode:
+        if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)

@@ -120,8 +119,8 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More detailed information can be found in the paper about ViT-MoE.
 
-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
     :param drop_tks: Whether to drop tokens in evaluation

@@ -157,7 +156,7 @@ class Top2Router(nn.Module):
         assert capacity > 0
         return capacity
 
-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         # inputs: [s, h]
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -177,10 +176,10 @@ class Top2Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

@@ -195,7 +194,7 @@ class Top2Router(nn.Module):
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)
 
-        if cuda_mode:
+        if use_kernel:
             mask1 = torch.sum(mask1, dim=-1)
             mask2 = torch.sum(mask2, dim=-1)

@@ -241,34 +240,36 @@ class MoeLayer(nn.Module):
         self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
-        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.ep_group = experts.dist_info.ep_group
+        self.ep_size = experts.dist_info.ep_size
+        self.num_local_experts = experts.num_local_experts
 
     def a2a_process(self, dispatch_data: torch.Tensor):
-        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
 
         input_shape = expert_input.shape
 
-        expert_input = expert_input.reshape(moe_env.model_parallel_size,
-                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
 
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
 
-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
 
     def tp_process(self, dispatch_data: torch.Tensor):
-        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
         expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
-        router_res = self.router(gate_output, self.cuda_mode)
+        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
-        if self.cuda_mode:
+        if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:

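a2a_process now reads its communicator and sizes from the experts' dist_info instead of a hard-coded ParallelMode, but the shape flow is the same: dispatch data shaped [num_experts, capacity, d_model] is exchanged all-to-all, viewed as [ep_size, num_local_experts, capacity, d_model] for the local experts, and exchanged back. A single-process sketch of just the reshape step (communication omitted, sizes illustrative):

    import torch

    ep_size, num_local_experts, capacity, d_model = 2, 3, 4, 8
    num_experts = ep_size * num_local_experts

    dispatch_data = torch.randn(num_experts, capacity, d_model)      # [e, c, h]
    expert_input = dispatch_data.reshape(ep_size, num_local_experts, -1, d_model)
    print(expert_input.shape)        # torch.Size([2, 3, 4, 8])

    # After the experts run, the output is flattened back to the pre-all-to-all shape.
    expert_output = expert_input.reshape(num_experts, capacity, d_model)
    print(expert_output.shape)       # torch.Size([6, 4, 8])
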
@@ -276,16 +277,16 @@ class MoeLayer(nn.Module):
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
 
         # dispatch_data [e, c, h]
-        if self.experts.comm == "all_to_all":
+        if self.experts.comm_name == "all_to_all":
             expert_output = self.a2a_process(dispatch_data)
-        elif self.experts.comm == "all_gather":
+        elif self.experts.comm_name == "all_gather":
             expert_output = self.tp_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
 
-        if self.cuda_mode:
+        if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:

@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts

@@ -36,7 +36,7 @@ class UniformNoiseGenerator:
     :type eps: float
     """
 
-    def __init__(self, eps: float):
+    def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
                                                            high=torch.tensor(1.0 + eps,
                                                                              device=get_current_device())).rsample

@@ -55,10 +55,10 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
 
 
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    moe_mp_size = moe_env.model_parallel_size
-    if num_experts % moe_mp_size == 0:
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % moe_mp_size == 0:
+    elif d_ff % mep_size == 0:
         return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
+        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")

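build_ffn_experts now keys off MOE_CONTEXT.max_ep_size: FFNExperts is chosen when the expert count and the maximum expert parallel size divide one another (the second condition is new in this commit), TPExperts when only d_ff is divisible, and any other combination is rejected. A hedged usage sketch; the import path and the launch step are assumptions, not shown in this diff:

    # Usage sketch; assumes colossalai.launch(...) has already initialized MOE_CONTEXT.
    from colossalai.nn.layer.moe.utils import build_ffn_experts   # import path assumed

    d_model, d_ff = 1024, 4096
    experts = build_ffn_experts(num_experts=8, d_model=d_model, d_ff=d_ff, drop_rate=0.1)
    # With MOE_CONTEXT.max_ep_size == 4:  8 % 4 == 0, so FFNExperts is returned.
    # With MOE_CONTEXT.max_ep_size == 16: 16 % 8 == 0, so the new branch also picks FFNExperts.
    # With MOE_CONTEXT.max_ep_size == 3:  neither divides and 4096 % 3 != 0, so NotImplementedError is raised.
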
@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 
 
 @LOSSES.register_module

@@ -14,6 +14,7 @@ class MoeCrossEntropyLoss(_Loss):
     :type aux_weight: float, optional
     """
+
     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)

@@ -21,7 +22,7 @@ class MoeCrossEntropyLoss(_Loss):
 
     def forward(self, *args):
         main_loss = self.loss(*args)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss

@@ -37,6 +38,7 @@ class MoeLoss(_Loss):
     :type aux_weight: float
     :type loss_fn: Callable
     """
+
     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)

@@ -44,5 +46,5 @@ class MoeLoss(_Loss):
 
     def forward(self, *args, **kwargs):
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss

@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

@@ -45,7 +45,7 @@ def sync_moe_model_param(model: nn.Module):
 
     for ep_size in param_dict:
         # When ep_size = world_size, communication is not needed
-        if ep_size != 1 and ep_size != moe_env.world_size:
-            src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
+        if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+            src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
             for param in param_dict[ep_size]:
                 dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)