[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models

* [hotfix] fix bugs in unit test
pull/3367/head
HELSON 2023-03-31 09:20:33 +08:00 committed by GitHub
parent fee2af8610
commit 1a1d68b053
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 517 additions and 384 deletions

View File

@ -1,9 +1,10 @@
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
]
from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
]

View File

@ -0,0 +1,40 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from .experts import MoeExperts
def save_moe_model(model: nn.Module, save_path: str):
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, save_path)
dist.barrier()
def load_moe_model(model: nn.Module, load_path: str):
state_dict = torch.load(load_path)
for prefix, module in model.named_modules():
if prefix.endswith('.moe_layer.experts'):
# this module should be an Experts instance
assert isinstance(module, MoeExperts)
ep_rank = dist.get_rank(module.dist_info.ep_group)
num_local = module.num_local_experts
for i in range(num_local):
expert_id = ep_rank * num_local + i
for name, _ in module.experts[i].named_parameters():
cur_key = f'{prefix}.experts.{i}.{name}'
param_key = f'{prefix}.experts.{expert_id}.{name}'
load_param = state_dict[param_key]
state_dict[cur_key] = load_param
for name, _ in module.experts[0].named_parameters():
pop_pre = f'{prefix}.experts.'
pop_suf = f'.{name}'
for i in range(num_local, module.num_total_experts):
pop_key = f'{pop_pre}{i}{pop_suf}'
state_dict.pop(pop_key)
model.load_state_dict(state_dict)

View File

@ -1,172 +1,203 @@
import math
import torch
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.zero.init_ctx import no_shard_zero_decrator
from typing import Type
class MoeExperts(nn.Module):
"""Basic class for experts in MoE. It stores what kind of communication expersts use
to exchange tokens, how many experts in a single GPU and parallel information such as
expert parallel size, data parallel size and their distributed communication groups.
"""
def __init__(self, comm_name: str, num_experts: int):
super().__init__()
assert comm_name in {"all_to_all", "all_gather"}, \
"This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm_name = comm_name
# Get the configuration of experts' deployment and parallel information from moe contex
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
Args:
expert_cls (:class:`torch.nn.Module`): The class of all experts
num_experts (int): The number of experts
expert_args: Args used to initialize experts, the args could be found in corresponding expert class
"""
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts)
# Use seed to make every expert different from others
with seed(ParallelMode.TENSOR):
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
# Attach parallel information for all parameters in Experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_info', self.dist_info)
def forward(self, inputs: torch.Tensor):
# Split inputs for each expert
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
# Get outputs from each expert
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
# Concatenate all outputs together
output = torch.cat(expert_output, dim=1).contiguous()
return output
class FFNExperts(MoeExperts):
"""Use torch.bmm to speed up for multiple experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_to_all", num_experts)
self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__('moe_info', self.dist_info)
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
case that the number of experts can't be divied by maximum expert parallel size or
maximum expert parallel size can't be divied by the number of experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
"d_ff should be divied by maximum expert parallel size"
p_ff = d_ff // MOE_CONTEXT.max_ep_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__('moe_info', self.dist_info)
self.w2.__setattr__('moe_info', self.dist_info)
self.b1.__setattr__('moe_info', self.dist_info)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
import math
from copy import deepcopy
from typing import Type
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator
class MoeExperts(nn.Module):
"""Basic class for experts in MoE. It stores what kind of communication expersts use
to exchange tokens, how many experts in a single GPU and parallel information such as
expert parallel size, data parallel size and their distributed communication groups.
"""
def __init__(self, comm_name: str, num_experts: int):
super().__init__()
assert comm_name in {"all_to_all", "all_gather"}, \
"This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm_name = comm_name
self.num_total_experts = num_experts
# Get the configuration of experts' deployment and parallel information from moe contex
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
Args:
expert_cls (:class:`torch.nn.Module`): The class of all experts
num_experts (int): The number of experts
expert_args: Args used to initialize experts, the args could be found in corresponding expert class
"""
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts)
# Use seed to make every expert different from others
with seed(ParallelMode.TENSOR):
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
# Attach parallel information for all parameters in Experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_info', self.dist_info)
def forward(self, inputs: torch.Tensor):
# Split inputs for each expert
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
# Get outputs from each expert
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
# Concatenate all outputs together
output = torch.cat(expert_output, dim=1).contiguous()
return output
def state_dict(self, destination=None, prefix='', keep_vars=False):
assert keep_vars == False, "Only support keep_vars=False now"
dp_rank = dist.get_rank(self.dist_info.dp_group)
ep_rank = dist.get_rank(self.dist_info.ep_group)
submodule_dict = dict()
example_submodule = None
for name, subm in self.experts.named_modules():
if subm is self.experts:
continue
module_number = self.num_local_experts * ep_rank + int(name)
submodule_dict[module_number] = subm
example_submodule = subm
if dp_rank == 0:
local_prefix = prefix + 'experts.'
buffer_module = deepcopy(example_submodule)
for i in range(self.num_total_experts):
source_rank = i // self.num_local_experts
current_prefix = local_prefix + str(i) + '.'
comm_module = submodule_dict.get(i, buffer_module)
for name, param in comm_module.named_parameters():
dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
if ep_rank == 0:
destination[current_prefix + name] = param.data.cpu()
dist.barrier()
class FFNExperts(MoeExperts):
"""Use torch.bmm to speed up for multiple experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_to_all", num_experts)
self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__('moe_info', self.dist_info)
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
case that the number of experts can't be divied by maximum expert parallel size or
maximum expert parallel size can't be divied by the number of experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
"d_ff should be divied by maximum expert parallel size"
p_ff = d_ff // MOE_CONTEXT.max_ep_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__('moe_info', self.dist_info)
self.w2.__setattr__('moe_info', self.dist_info)
self.b1.__setattr__('moe_info', self.dist_info)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]

View File

@ -1,203 +1,210 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
ReduceScatter, MoeDispatch, MoeCombine
from colossalai.nn.layer.moe.experts import MoeExperts, Experts
from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from typing import Optional, Type, Tuple
@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all comunication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (MoeRouter): Instance of router used in routing.
experts (MoeExperts): Instance of experts generated by Expert.
"""
def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router: MoeRouter = router
self.experts: MoeExperts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
self.ep_group = experts.dist_info.ep_group
self.ep_size = experts.dist_info.ep_size
self.num_local_experts = experts.num_local_experts
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output
def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out
def forward(self, inputs: torch.Tensor) -> Tuple:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h]
if self.experts.comm_name == "all_to_all":
expert_output = self.a2a_process(dispatch_data)
elif self.experts.comm_name == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
l_aux = self.router.pop_routing_loss()
return ans, l_aux
class MoeModule(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
drop_tks (bool, optional): Whether drops tokens in evaluation
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
More information can be found in `Microsoft paper`_.
residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
expert_args (optional): The args of expert when no instance is given
.. _Switch Transformer paper:
https://arxiv.org/abs/2101.03961
.. _ViT-MoE paper:
https://arxiv.org/abs/2106.05974
.. _Microsoft paper:
https://arxiv.org/abs/2201.05596
"""
def __init__(self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args):
super().__init__()
noisy_func = None
if noisy_policy is not None:
if noisy_policy == 'Jitter':
noisy_func = UniformNoiseGenerator()
elif noisy_policy == 'Gaussian':
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
if top_k == 1:
moe_router_cls = Top1Router
elif top_k == 2:
moe_router_cls = Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
self.residual_module = residual_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)
with no_shard_zero_context():
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
if expert_instance is not None:
self.experts = expert_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when experts instance is not given"
self.experts = Experts(expert_cls, num_experts, **expert_args)
self.moe_layer = MoeLayer(dim_model=dim_model,
num_experts=num_experts,
router=self.moe_router,
experts=self.experts)
def forward(self, inputs: torch.Tensor):
moe_output, l_aux = self.moe_layer(inputs)
if self.use_residual:
residual_output = self.residual_module(inputs)
combine_coef = self.residual_combine(inputs)
combine_coef = F.softmax(combine_coef, dim=-1)
output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
else:
output = moe_output
return output, l_aux
import math
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import (
COL_MOE_KERNEL_FLAG,
AllGather,
AllToAll,
MoeCombine,
MoeDispatch,
ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all comunication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (MoeRouter): Instance of router used in routing.
experts (MoeExperts): Instance of experts generated by Expert.
"""
def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router: MoeRouter = router
self.experts: MoeExperts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
self.ep_group = experts.dist_info.ep_group
self.ep_size = experts.dist_info.ep_size
self.num_local_experts = experts.num_local_experts
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output
def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out
def forward(self, inputs: torch.Tensor) -> Tuple:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h]
if self.experts.comm_name == "all_to_all":
expert_output = self.a2a_process(dispatch_data)
elif self.experts.comm_name == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
l_aux = self.router.pop_routing_loss()
return ans, l_aux
class MoeModule(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
drop_tks (bool, optional): Whether drops tokens in evaluation
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
More information can be found in `Microsoft paper`_.
residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
expert_args (optional): The args of expert when no instance is given
.. _Switch Transformer paper:
https://arxiv.org/abs/2101.03961
.. _ViT-MoE paper:
https://arxiv.org/abs/2106.05974
.. _Microsoft paper:
https://arxiv.org/abs/2201.05596
"""
def __init__(self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args):
super().__init__()
noisy_func = None
if noisy_policy is not None:
if noisy_policy == 'Jitter':
noisy_func = UniformNoiseGenerator()
elif noisy_policy == 'Gaussian':
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
if top_k == 1:
moe_router_cls = Top1Router
elif top_k == 2:
moe_router_cls = Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
self.residual_module = residual_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)
with no_shard_zero_context():
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
if expert_instance is not None:
my_experts = expert_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when experts instance is not given"
my_experts = Experts(expert_cls, num_experts, **expert_args)
self.moe_layer = MoeLayer(dim_model=dim_model,
num_experts=num_experts,
router=self.moe_router,
experts=my_experts)
def forward(self, inputs: torch.Tensor):
moe_output, l_aux = self.moe_layer(inputs)
if self.use_residual:
residual_output = self.residual_module(inputs)
combine_coef = self.residual_combine(inputs)
combine_coef = F.softmax(combine_coef, dim=-1)
output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
else:
output = moe_output
return output, l_aux

View File

@ -0,0 +1,54 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.common import CONFIG
def exam_moe_checkpoint():
with ColoInitContext(device=get_current_device()):
model = MoeModel(checkpoint=True)
save_moe_model(model, 'temp_path.pth')
with ColoInitContext(device=get_current_device()):
other_model = MoeModel(checkpoint=True)
load_moe_model(other_model, 'temp_path.pth')
state_0 = model.state_dict()
state_1 = other_model.state_dict()
for k, v in state_0.items():
u = state_1.get(k)
assert torch.equal(u.data, v.data)
if dist.get_rank() == 0:
os.remove('temp_path.pth')
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
exam_moe_checkpoint()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_checkpoint(world_size=4)