mirror of https://github.com/hpcaitech/ColossalAI
[moe] add checkpoint for moe models (#3354)
* [moe] add checkpoint for moe models
* [hotfix] fix bugs in unit test

branch: pull/3367/head
parent: fee2af8610
commit: 1a1d68b053
@@ -1,9 +1,10 @@  (package __init__.py, resulting file: re-export the new checkpoint helpers)

from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
]
@@ -0,0 +1,40 @@  (new file checkpoint.py: save/load helpers for MoE models)

import torch
import torch.distributed as dist
import torch.nn as nn

from .experts import MoeExperts


def save_moe_model(model: nn.Module, save_path: str):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, save_path)
    dist.barrier()


def load_moe_model(model: nn.Module, load_path: str):
    state_dict = torch.load(load_path)

    for prefix, module in model.named_modules():
        if prefix.endswith('.moe_layer.experts'):
            # this module should be an Experts instance
            assert isinstance(module, MoeExperts)

            ep_rank = dist.get_rank(module.dist_info.ep_group)
            num_local = module.num_local_experts
            for i in range(num_local):
                expert_id = ep_rank * num_local + i
                for name, _ in module.experts[i].named_parameters():
                    cur_key = f'{prefix}.experts.{i}.{name}'
                    param_key = f'{prefix}.experts.{expert_id}.{name}'
                    load_param = state_dict[param_key]
                    state_dict[cur_key] = load_param

            for name, _ in module.experts[0].named_parameters():
                pop_pre = f'{prefix}.experts.'
                pop_suf = f'.{name}'
                for i in range(num_local, module.num_total_experts):
                    pop_key = f'{pop_pre}{i}{pop_suf}'
                    state_dict.pop(pop_key)

    model.load_state_dict(state_dict)
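A minimal usage sketch (not part of the patch): both helpers are collective calls, so every rank must reach them, while only rank 0 touches the file. It assumes a distributed MoE model has already been built and the checkpoint name below is purely illustrative.

# hypothetical usage, assuming colossalai.launch(...) and MOE_CONTEXT setup already ran
from colossalai.nn.layer.moe import save_moe_model, load_moe_model

save_moe_model(model, 'moe_ckpt.pth')          # rank 0 writes the gathered state dict, all ranks sync
load_moe_model(fresh_model, 'moe_ckpt.pth')    # each rank keeps only the expert slices it owns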
@@ -1,172 +1,203 @@  (experts.py, resulting file: imports regrouped, MoeExperts now records num_total_experts, and Experts gains a state_dict that gathers all experts across the expert parallel group)

import math
from copy import deepcopy
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator


class MoeExperts(nn.Module):
    """Basic class for experts in MoE. It stores what kind of communication experts use
    to exchange tokens, how many experts live on a single GPU and parallel information such as
    expert parallel size, data parallel size and their distributed communication groups.
    """

    def __init__(self, comm_name: str, num_experts: int):
        super().__init__()
        assert comm_name in {"all_to_all", "all_gather"}, \
            "This kind of communication has not been implemented yet.\n Please use Experts build function."
        self.comm_name = comm_name
        # Get the configuration of experts' deployment and parallel information from the moe context
        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
        self.num_total_experts = num_experts


@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
    """A wrapper class to create experts. It will create E experts across the
    moe model parallel group, where E is the number of experts. Every expert
    is an instance of the class given as 'expert_cls' in the initialization parameters.

    Args:
        expert_cls (:class:`torch.nn.Module`): The class of all experts
        num_experts (int): The number of experts
        expert_args: Args used to initialize experts, the args could be found in the corresponding expert class
    """

    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
        super().__init__("all_to_all", num_experts)

        # Use seed to make every expert different from others
        with seed(ParallelMode.TENSOR):
            self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])

        # Attach parallel information for all parameters in Experts
        for exp in self.experts:
            for param in exp.parameters():
                param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs: torch.Tensor):
        # Split inputs for each expert
        expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
        expert_output = []

        # Get outputs from each expert
        for i in range(self.num_local_experts):
            expert_output.append(self.experts[i](expert_input[i]))

        # Concatenate all outputs together
        output = torch.cat(expert_output, dim=1).contiguous()
        return output

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        assert keep_vars == False, "Only support keep_vars=False now"
        dp_rank = dist.get_rank(self.dist_info.dp_group)
        ep_rank = dist.get_rank(self.dist_info.ep_group)
        submodule_dict = dict()
        example_submodule = None
        for name, subm in self.experts.named_modules():
            if subm is self.experts:
                continue
            module_number = self.num_local_experts * ep_rank + int(name)
            submodule_dict[module_number] = subm
            example_submodule = subm

        if dp_rank == 0:
            local_prefix = prefix + 'experts.'
            buffer_module = deepcopy(example_submodule)
            for i in range(self.num_total_experts):
                source_rank = i // self.num_local_experts
                current_prefix = local_prefix + str(i) + '.'
                comm_module = submodule_dict.get(i, buffer_module)
                for name, param in comm_module.named_parameters():
                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
                    if ep_rank == 0:
                        destination[current_prefix + name] = param.data.cpu()

        dist.barrier()


class FFNExperts(MoeExperts):
    """Use torch.bmm to speed up computation for multiple experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_to_all", num_experts)

        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)
            nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        for param in self.parameters():
            param.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, el, c, h]

        el = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(el, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        with seed(ParallelMode.TENSOR):
            outputs = self.drop(out_model)    # outputs [el, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs


class TPExperts(MoeExperts):
    """Use tensor parallelism to split each expert evenly, which can deploy experts in the
    case that the number of experts can't be divided by the maximum expert parallel size, or
    the maximum expert parallel size can't be divided by the number of experts.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
            "d_ff should be divided by maximum expert parallel size"

        p_ff = d_ff // MOE_CONTEXT.max_ep_size

        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)

        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=s1)
            nn.init.trunc_normal_(self.b1, std=s1)
            nn.init.trunc_normal_(self.w2, std=s2)
            nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        self.w1.__setattr__('moe_info', self.dist_info)
        self.w2.__setattr__('moe_info', self.dist_info)
        self.b1.__setattr__('moe_info', self.dist_info)

    def forward(self, inputs):    # inputs [g, e, c, h]

        e = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(e, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            out_inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
        outputs = self.drop(out_model)    # outputs [e, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs    # outputs [g, e, c, h]
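The new state_dict writes every expert under a global index (experts.0 through experts.E-1) from expert-parallel rank 0, while each rank's live module only holds num_local_experts of them; load_moe_model earlier in this commit inverts that layout. A tiny illustrative helper (hypothetical, not part of the patch) spelling out the index arithmetic it relies on:

# Hypothetical helper: where global expert i lives under expert parallelism.
# Mirrors expert_id = ep_rank * num_local + i used in the checkpoint code above.
def locate_expert(expert_id: int, num_local: int):
    ep_rank = expert_id // num_local      # expert-parallel rank that owns the expert
    local_slot = expert_id % num_local    # index inside that rank's nn.ModuleList
    return ep_rank, local_slot

# e.g. 8 experts over 4 expert-parallel ranks (num_local = 2): expert 5 -> rank 2, slot 1
assert locate_expert(5, 2) == (2, 1)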
@@ -1,203 +1,210 @@  (layers.py, resulting file: imports regrouped, and MoeModule now passes its experts to MoeLayer through a local variable instead of keeping self.experts)

import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import (
    COL_MOE_KERNEL_FLAG,
    AllGather,
    AllToAll,
    MoeCombine,
    MoeDispatch,
    ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator


@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
    """A MoE layer that feeds its input tensor to its gate and uses the output logits
    to route all tokens. It is mainly used to exchange all tokens for every expert across
    the moe tensor group by all-to-all communication. Then it gets the output of all
    experts and exchanges the output. At last it returns the output of the moe system.

    Args:
        dim_model (int): Dimension of model.
        num_experts (int): The number of experts.
        router (MoeRouter): Instance of router used in routing.
        experts (MoeExperts): Instance of experts generated by Expert.
    """

    def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
        super().__init__()
        self.d_model = dim_model
        self.num_experts = num_experts
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
        self.router: MoeRouter = router
        self.experts: MoeExperts = experts
        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
        self.ep_group = experts.dist_info.ep_group
        self.ep_size = experts.dist_info.ep_size
        self.num_local_experts = experts.num_local_experts

        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

    def a2a_process(self, dispatch_data: torch.Tensor):
        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
        input_shape = expert_input.shape
        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
        expert_output = self.experts(expert_input)
        expert_output = expert_output.reshape(input_shape)
        expert_output = AllToAll.apply(expert_output, self.ep_group)
        return expert_output

    def tp_process(self, dispatch_data: torch.Tensor):
        expert_in = AllGather.apply(dispatch_data, self.ep_group)
        expert_out = self.experts(expert_in)
        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
        return expert_out

    def forward(self, inputs: torch.Tensor) -> Tuple:
        # reshape the input tokens
        tokens = inputs.reshape(-1, self.d_model)

        # the data type of the inputs in the gating should be fp32
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        gate_output = F.linear(fp32_input, fp32_weight)

        # the result from the router
        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

        if self.use_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
        else:
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # dispatch_data [e, c, h]
        if self.experts.comm_name == "all_to_all":
            expert_output = self.a2a_process(dispatch_data)
        elif self.experts.comm_name == "all_gather":
            expert_output = self.tp_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                      "build function.")
        # expert_output [e, c, h]
        if self.use_kernel:
            expert_output = expert_output.reshape(-1, self.d_model)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        l_aux = self.router.pop_routing_loss()
        return ans, l_aux


class MoeModule(nn.Module):
    """A class for users to create MoE modules in their models.

    Args:
        dim_model (int): Hidden dimension of training model
        num_experts (int): The number of experts
        top_k (int, optional): The number of experts each token is dispatched to
        capacity_factor_train (float, optional): Capacity factor in routing during training
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
        min_capacity (int, optional): The minimum capacity of each expert
        noisy_policy (str, optional): The policy of the noisy function. Now we have 'Jitter' and 'Gaussian'.
            'Jitter' can be found in `Switch Transformer paper`_.
            'Gaussian' can be found in `ViT-MoE paper`_.
        drop_tks (bool, optional): Whether to drop tokens in evaluation
        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
            More information can be found in `Microsoft paper`_.
        residual_instance (nn.Module, optional): The instance of the residual module in Residual MoE
        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
        expert_args (optional): The args used to initialize experts when no instance is given

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    .. _Microsoft paper:
        https://arxiv.org/abs/2201.05596
    """

    def __init__(self,
                 dim_model: int,
                 num_experts: int,
                 top_k: int = 1,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_policy: Optional[str] = None,
                 drop_tks: bool = True,
                 use_residual: bool = False,
                 residual_instance: Optional[nn.Module] = None,
                 expert_instance: Optional[MoeExperts] = None,
                 expert_cls: Optional[Type[nn.Module]] = None,
                 **expert_args):
        super().__init__()

        noisy_func = None
        if noisy_policy is not None:
            if noisy_policy == 'Jitter':
                noisy_func = UniformNoiseGenerator()
            elif noisy_policy == 'Gaussian':
                noisy_func = NormalNoiseGenerator(num_experts)
            else:
                raise NotImplementedError("Unsupported input noisy policy")

        if top_k == 1:
            moe_router_cls = Top1Router
        elif top_k == 2:
            moe_router_cls = Top2Router
        else:
            raise NotImplementedError("top_k > 2 is not supported yet")

        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
                                         capacity_factor_eval=capacity_factor_eval,
                                         min_capacity=min_capacity,
                                         noisy_func=noisy_func,
                                         drop_tks=drop_tks)
        self.use_residual = use_residual
        if use_residual:
            if residual_instance is not None:
                self.residual_module = residual_instance
            else:
                assert expert_cls is not None, \
                    "Expert class can't be None when residual instance is not given"
                self.residual_module = expert_cls(**expert_args)

            with no_shard_zero_context():
                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

        if expert_instance is not None:
            my_experts = expert_instance
        else:
            assert expert_cls is not None, \
                "Expert class can't be None when experts instance is not given"
            my_experts = Experts(expert_cls, num_experts, **expert_args)

        self.moe_layer = MoeLayer(dim_model=dim_model,
                                  num_experts=num_experts,
                                  router=self.moe_router,
                                  experts=my_experts)

    def forward(self, inputs: torch.Tensor):
        moe_output, l_aux = self.moe_layer(inputs)

        if self.use_residual:
            residual_output = self.residual_module(inputs)
            combine_coef = self.residual_combine(inputs)
            combine_coef = F.softmax(combine_coef, dim=-1)
            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
        else:
            output = moe_output

        return output, l_aux
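For context, a minimal construction sketch (hypothetical, not from the patch): SimpleFFN stands in for any user-defined nn.Module whose keyword arguments are forwarded through expert_args, and a distributed environment plus MOE_CONTEXT setup are assumed to be initialized already.

# hypothetical usage of MoeModule; SimpleFFN and the sizes below are illustrative only
moe_block = MoeModule(dim_model=512,
                      num_experts=8,
                      top_k=2,                  # selects Top2Router
                      noisy_policy='Jitter',    # UniformNoiseGenerator on the gate logits
                      expert_cls=SimpleFFN,     # forwarded to Experts(expert_cls, num_experts, ...)
                      d_model=512,
                      d_ff=2048)

output, l_aux = moe_block(hidden_states)    # l_aux is the router's auxiliary balancing loss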
@@ -0,0 +1,54 @@  (new file: unit test for MoE checkpointing)

import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.common import CONFIG


def exam_moe_checkpoint():
    with ColoInitContext(device=get_current_device()):
        model = MoeModel(checkpoint=True)
    save_moe_model(model, 'temp_path.pth')

    with ColoInitContext(device=get_current_device()):
        other_model = MoeModel(checkpoint=True)
    load_moe_model(other_model, 'temp_path.pth')

    state_0 = model.state_dict()
    state_1 = other_model.state_dict()
    for k, v in state_0.items():
        u = state_1.get(k)
        assert torch.equal(u.data, v.data)

    if dist.get_rank() == 0:
        os.remove('temp_path.pth')


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    exam_moe_checkpoint()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_checkpoint(world_size=4)