[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models

* [hotfix] fix bugs in unit test
HELSON 2023-03-31 09:20:33 +08:00 committed by GitHub
parent fee2af8610
commit 1a1d68b053
5 changed files with 517 additions and 384 deletions


@@ -1,9 +1,10 @@
from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
]


@@ -0,0 +1,40 @@
import torch
import torch.distributed as dist
import torch.nn as nn

from .experts import MoeExperts


def save_moe_model(model: nn.Module, save_path: str):
    # state_dict() is collective for MoE experts, so every rank must call it,
    # but only rank 0 writes the merged checkpoint to disk
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, save_path)
    dist.barrier()


def load_moe_model(model: nn.Module, load_path: str):
    state_dict = torch.load(load_path)
    for prefix, module in model.named_modules():
        if prefix.endswith('.moe_layer.experts'):
            # this module should be an Experts instance
            assert isinstance(module, MoeExperts)

            ep_rank = dist.get_rank(module.dist_info.ep_group)
            num_local = module.num_local_experts

            # copy the checkpoint entry of each globally-indexed expert owned by
            # this rank into the key of the corresponding local expert slot
            for i in range(num_local):
                expert_id = ep_rank * num_local + i
                for name, _ in module.experts[i].named_parameters():
                    cur_key = f'{prefix}.experts.{i}.{name}'
                    param_key = f'{prefix}.experts.{expert_id}.{name}'
                    load_param = state_dict[param_key]
                    state_dict[cur_key] = load_param

            # drop checkpoint entries whose expert index does not map to a local slot
            for name, _ in module.experts[0].named_parameters():
                pop_pre = f'{prefix}.experts.'
                pop_suf = f'.{name}'
                for i in range(num_local, module.num_total_experts):
                    pop_key = f'{pop_pre}{i}{pop_suf}'
                    state_dict.pop(pop_key)

    model.load_state_dict(state_dict)
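
For context, a minimal usage sketch of the two helpers above (not part of this commit): the model objects, the checkpoint path, and the prior colossalai.launch / MOE_CONTEXT.setup calls are assumed to already exist.

# hypothetical usage sketch, assuming the distributed environment is already initialized
from colossalai.nn.layer.moe import save_moe_model, load_moe_model

save_moe_model(model, 'moe_checkpoint.pth')        # rank 0 writes the merged state dict
load_moe_model(fresh_model, 'moe_checkpoint.pth')  # each rank restores only its local experts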


@@ -1,172 +1,203 @@
import math
from copy import deepcopy
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator


class MoeExperts(nn.Module):
    """Basic class for experts in MoE. It stores what kind of communication experts use
    to exchange tokens, how many experts reside on a single GPU and parallel information such as
    expert parallel size, data parallel size and their distributed communication groups.
    """

    def __init__(self, comm_name: str, num_experts: int):
        super().__init__()
        assert comm_name in {"all_to_all", "all_gather"}, \
            "This kind of communication has not been implemented yet.\n Please use Experts build function."
        self.comm_name = comm_name
        self.num_total_experts = num_experts
        # Get the configuration of experts' deployment and parallel information from moe context
        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
    """A wrapper class to create experts. It will create E experts across the
    moe model parallel group, where E is the number of experts. Every expert
    is an instance of the class given by 'expert_cls' in the initialization parameters.

    Args:
        expert_cls (:class:`torch.nn.Module`): The class of all experts
        num_experts (int): The number of experts
        expert_args: Args used to initialize experts, the args could be found in corresponding expert class
    """

    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
        super().__init__("all_to_all", num_experts)

        # Use seed to make every expert different from others
        with seed(ParallelMode.TENSOR):
            self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])

        # Attach parallel information for all parameters in Experts
        for exp in self.experts:
            for param in exp.parameters():
                param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs: torch.Tensor):
        # Split inputs for each expert
        expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
        expert_output = []

        # Get outputs from each expert
        for i in range(self.num_local_experts):
            expert_output.append(self.experts[i](expert_input[i]))

        # Concatenate all outputs together
        output = torch.cat(expert_output, dim=1).contiguous()
        return output

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        assert keep_vars == False, "Only support keep_vars=False now"
        dp_rank = dist.get_rank(self.dist_info.dp_group)
        ep_rank = dist.get_rank(self.dist_info.ep_group)

        # map each local expert to its global expert id within the expert parallel group
        submodule_dict = dict()
        example_submodule = None
        for name, subm in self.experts.named_modules():
            if subm is self.experts:
                continue
            module_number = self.num_local_experts * ep_rank + int(name)
            submodule_dict[module_number] = subm
            example_submodule = subm

        if dp_rank == 0:
            local_prefix = prefix + 'experts.'
            buffer_module = deepcopy(example_submodule)
            # broadcast every global expert from the rank that owns it,
            # so that ep_rank 0 can record the full expert list
            for i in range(self.num_total_experts):
                source_rank = i // self.num_local_experts
                current_prefix = local_prefix + str(i) + '.'
                comm_module = submodule_dict.get(i, buffer_module)
                for name, param in comm_module.named_parameters():
                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
                    if ep_rank == 0:
                        destination[current_prefix + name] = param.data.cpu()

        dist.barrier()


class FFNExperts(MoeExperts):
    """Use torch.bmm to speed up computation for multiple experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_to_all", num_experts)

        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)
            nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        for param in self.parameters():
            param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, el, c, h]

        el = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(el, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        with seed(ParallelMode.TENSOR):
            outputs = self.drop(out_model)    # outputs [el, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs


class TPExperts(MoeExperts):
    """Use tensor parallelism to split each expert evenly, which can deploy experts in
    case that the number of experts can't be divided by the maximum expert parallel size or
    the maximum expert parallel size can't be divided by the number of experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
            "d_ff should be divisible by maximum expert parallel size"

        p_ff = d_ff // MOE_CONTEXT.max_ep_size

        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)
            nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        self.w1.__setattr__('moe_info', self.dist_info)
        self.w2.__setattr__('moe_info', self.dist_info)
        self.b1.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, e, c, h]

        e = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(e, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        outputs = self.drop(out_model)    # outputs [e, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs    # outputs [g, e, c, h]
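
The new Experts.state_dict above assembles the full expert list by broadcasting each globally-indexed expert from the expert-parallel rank that owns it, so rank 0 can record every entry. A stripped-down sketch of that gather-by-broadcast pattern with plain torch.distributed follows; the function and variable names are illustrative, not from this commit.

# illustrative sketch only, assuming torch.distributed is already initialized
import torch
import torch.distributed as dist


def gather_all_experts(local_experts, num_local):
    """Collect every expert tensor onto rank 0 via broadcasts (sketch)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gathered = {}
    for global_id in range(world_size * num_local):
        owner = global_id // num_local    # rank that holds this expert locally
        if rank == owner:
            buf = local_experts[global_id % num_local].clone()
        else:
            buf = torch.empty_like(local_experts[0])    # scratch buffer on non-owners
        dist.broadcast(buf, src=owner)
        if rank == 0:
            gathered[global_id] = buf.cpu()
    return gathered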


@@ -1,203 +1,210 @@
import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import (
    COL_MOE_KERNEL_FLAG,
    AllGather,
    AllToAll,
    MoeCombine,
    MoeDispatch,
    ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator


@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
    """A MoE layer that feeds its input tensor to the gate and uses the output logits
    to route all tokens. It is mainly used to exchange all tokens for every expert across
    the moe tensor group by all-to-all communication. Then it gets the output of all
    experts and exchanges the output. At last it returns the output of the moe system.

    Args:
        dim_model (int): Dimension of model.
        num_experts (int): The number of experts.
        router (MoeRouter): Instance of router used in routing.
        experts (MoeExperts): Instance of experts generated by Expert.
    """

    def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
        super().__init__()
        self.d_model = dim_model
        self.num_experts = num_experts
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
        self.router: MoeRouter = router
        self.experts: MoeExperts = experts
        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
        self.ep_group = experts.dist_info.ep_group
        self.ep_size = experts.dist_info.ep_size
        self.num_local_experts = experts.num_local_experts

        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

    def a2a_process(self, dispatch_data: torch.Tensor):
        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
        input_shape = expert_input.shape
        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
        expert_output = self.experts(expert_input)
        expert_output = expert_output.reshape(input_shape)
        expert_output = AllToAll.apply(expert_output, self.ep_group)
        return expert_output

    def tp_process(self, dispatch_data: torch.Tensor):
        expert_in = AllGather.apply(dispatch_data, self.ep_group)
        expert_out = self.experts(expert_in)
        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
        return expert_out

    def forward(self, inputs: torch.Tensor) -> Tuple:
        # reshape the input tokens
        tokens = inputs.reshape(-1, self.d_model)

        # the data type of the inputs in the gating should be fp32
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        gate_output = F.linear(fp32_input, fp32_weight)

        # the result from the router
        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

        if self.use_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
        else:
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # dispatch_data [e, c, h]
        if self.experts.comm_name == "all_to_all":
            expert_output = self.a2a_process(dispatch_data)
        elif self.experts.comm_name == "all_gather":
            expert_output = self.tp_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                      "build function.")

        # expert_output [e, c, h]
        if self.use_kernel:
            expert_output = expert_output.reshape(-1, self.d_model)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        l_aux = self.router.pop_routing_loss()
        return ans, l_aux


class MoeModule(nn.Module):
    """A class for users to create MoE modules in their models.

    Args:
        dim_model (int): Hidden dimension of training model
        num_experts (int): The number of experts
        top_k (int, optional): The number of experts for dispatch of each token
        capacity_factor_train (float, optional): Capacity factor in routing during training
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
            'Jitter' can be found in `Switch Transformer paper`_.
            'Gaussian' can be found in `ViT-MoE paper`_.
        drop_tks (bool, optional): Whether to drop tokens in evaluation
        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
            More information can be found in `Microsoft paper`_.
        residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
        expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
        expert_args (optional): The args of expert when no instance is given

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    .. _Microsoft paper:
        https://arxiv.org/abs/2201.05596
    """

    def __init__(self,
                 dim_model: int,
                 num_experts: int,
                 top_k: int = 1,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_policy: Optional[str] = None,
                 drop_tks: bool = True,
                 use_residual: bool = False,
                 residual_instance: Optional[nn.Module] = None,
                 expert_instance: Optional[MoeExperts] = None,
                 expert_cls: Optional[Type[nn.Module]] = None,
                 **expert_args):
        super().__init__()

        noisy_func = None
        if noisy_policy is not None:
            if noisy_policy == 'Jitter':
                noisy_func = UniformNoiseGenerator()
            elif noisy_policy == 'Gaussian':
                noisy_func = NormalNoiseGenerator(num_experts)
            else:
                raise NotImplementedError("Unsupported input noisy policy")

        if top_k == 1:
            moe_router_cls = Top1Router
        elif top_k == 2:
            moe_router_cls = Top2Router
        else:
            raise NotImplementedError("top_k > 2 is not supported yet")

        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
                                         capacity_factor_eval=capacity_factor_eval,
                                         min_capacity=min_capacity,
                                         noisy_func=noisy_func,
                                         drop_tks=drop_tks)
        self.use_residual = use_residual
        if use_residual:
            if residual_instance is not None:
                self.residual_module = residual_instance
            else:
                assert expert_cls is not None, \
                    "Expert class can't be None when residual instance is not given"
                self.residual_module = expert_cls(**expert_args)

            with no_shard_zero_context():
                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

        if expert_instance is not None:
            my_experts = expert_instance
        else:
            assert expert_cls is not None, \
                "Expert class can't be None when experts instance is not given"
            my_experts = Experts(expert_cls, num_experts, **expert_args)

        self.moe_layer = MoeLayer(dim_model=dim_model,
                                  num_experts=num_experts,
                                  router=self.moe_router,
                                  experts=my_experts)

    def forward(self, inputs: torch.Tensor):
        moe_output, l_aux = self.moe_layer(inputs)

        if self.use_residual:
            residual_output = self.residual_module(inputs)
            combine_coef = self.residual_combine(inputs)
            combine_coef = F.softmax(combine_coef, dim=-1)
            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
        else:
            output = moe_output

        return output, l_aux
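
MoeModule is the user-facing entry point: it picks a Top1/Top2 router from top_k, wraps expert_cls (or a given instance) in Experts, and hands both to MoeLayer. A hypothetical construction sketch follows; it assumes the distributed environment and MOE_CONTEXT are already set up, and SimpleFFN plus the hidden_states tensor are placeholder names, not part of this commit.

# hypothetical usage sketch (not part of this commit)
import torch.nn as nn
from colossalai.nn.layer.moe import MoeModule


class SimpleFFN(nn.Module):

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)


moe = MoeModule(dim_model=512,
                num_experts=8,
                top_k=2,
                noisy_policy='Jitter',
                expert_cls=SimpleFFN,
                d_model=512,
                d_ff=2048)

# hidden_states is a placeholder tensor of shape [batch, seq, 512];
# forward returns the MoE output together with the auxiliary load-balancing loss
out, l_aux = moe(hidden_states)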


@@ -0,0 +1,54 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.common import CONFIG


def exam_moe_checkpoint():
    with ColoInitContext(device=get_current_device()):
        model = MoeModel(checkpoint=True)
    save_moe_model(model, 'temp_path.pth')

    with ColoInitContext(device=get_current_device()):
        other_model = MoeModel(checkpoint=True)
    load_moe_model(other_model, 'temp_path.pth')

    state_0 = model.state_dict()
    state_1 = other_model.state_dict()
    for k, v in state_0.items():
        u = state_1.get(k)
        assert torch.equal(u.data, v.data)

    if dist.get_rank() == 0:
        os.remove('temp_path.pth')


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    exam_moe_checkpoint()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_checkpoint(world_size=4)