2022-02-18 12:42:31 +00:00
|
|
|
import math
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from colossalai.context import ParallelMode, seed
|
|
|
|
from colossalai.utils import get_current_device
|
2022-03-23 10:03:39 +00:00
|
|
|
from colossalai.context.moe_context import MOE_CONTEXT
|
2022-03-29 09:57:59 +00:00
|
|
|
from colossalai.zero.init_ctx import no_shard_zero_decrator
|
2022-03-21 15:19:47 +00:00
|
|
|
from typing import Type
|
2022-02-18 12:42:31 +00:00
|
|
|
|
|
|
|
|
2022-02-27 14:28:39 +00:00
|
|
|
class MoeExperts(nn.Module):
    """Base class for MoE experts.

    Records which communication pattern the experts use to exchange tokens,
    how many experts live on this rank, and parallel information such as the
    expert parallel size, the data parallel size and their distributed
    communication groups, all obtained from ``MOE_CONTEXT``.

    Args:
        comm_name (str): Token-exchange pattern; must be ``"all_to_all"`` or ``"all_gather"``.
        num_experts (int): Expert count handed to ``MOE_CONTEXT.get_info``.
    """

    def __init__(self, comm_name: str, num_experts: int):
        super().__init__()
        supported_comms = {"all_to_all", "all_gather"}
        assert comm_name in supported_comms, \
            "This kind of communication has not been implemented yet.\n Please use Experts build function."
        self.comm_name = comm_name

        # MOE_CONTEXT decides how many experts this rank hosts and supplies the
        # matching distributed/parallel information for them.
        local_count, info = MOE_CONTEXT.get_info(num_experts)
        self.num_local_experts, self.dist_info = local_count, info
|
2022-02-27 14:28:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Experts(MoeExperts):
    """Wrapper that instantiates the experts of an MoE layer.

    E experts are created across the MoE model parallel group, where E is
    ``num_experts``; every expert is an instance of ``expert_cls``. This rank
    builds ``self.num_local_experts`` of them and exchanges tokens via
    ``"all_to_all"``.

    Args:
        expert_cls (:class:`torch.nn.Module`): The class of all experts
        num_experts (int): The number of experts
        expert_args: Keyword arguments forwarded to ``expert_cls`` when each expert is built
    """

    @no_shard_zero_decrator(is_replicated=False)
    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
        super().__init__("all_to_all", num_experts)

        # Build the local experts under the tensor-parallel seed so that every
        # expert is initialised differently from the others.
        with seed(ParallelMode.TENSOR):
            built = [expert_cls(**expert_args) for _ in range(self.num_local_experts)]
            self.experts = nn.ModuleList(built)

        # Tag each expert parameter with the parallel info from MOE_CONTEXT.
        every_param = (p for expert in self.experts for p in expert.parameters())
        for p in every_param:
            setattr(p, 'moe_info', self.dist_info)

    def forward(self, inputs: torch.Tensor):
        """Run each local expert on its own slice of ``inputs`` along dim 1.

        ``inputs`` is split into ``num_local_experts`` chunks on dim 1, chunk
        ``k`` goes through expert ``k``, and the results are concatenated back
        on dim 1 and made contiguous.
        """
        pieces = torch.chunk(inputs, self.num_local_experts, dim=1)
        results = [self.experts[k](pieces[k]) for k in range(self.num_local_experts)]
        return torch.cat(results, dim=1).contiguous()
|
|
|
|
|
|
|
|
|
2022-02-27 14:28:39 +00:00
|
|
|
class FFNExperts(MoeExperts):
    """Feed-forward experts computed together with batched matmuls (``torch.baddbmm``).

    The weights of all local experts are stacked along a leading expert
    dimension, so every local expert is evaluated in a single batched call.

    Args:
        num_experts (int): Total number of experts; the local count comes from MOE_CONTEXT.
        d_model (int): Size of the last dimension of the inputs.
        d_ff (int): Inner size of each expert's feed-forward projection.
        activation: Activation applied after the first projection; ``nn.GELU()`` when None.
        drop_rate (float): Dropout probability used after the activation and after the second projection.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_to_all", num_experts)

        dev = get_current_device()
        n_local = self.num_local_experts
        # Stacked per-expert weights; biases keep a singleton middle dim so
        # they broadcast over tokens inside baddbmm.
        self.w1 = nn.Parameter(torch.empty(n_local, d_model, d_ff, device=dev))
        self.b1 = nn.Parameter(torch.empty(n_local, 1, d_ff, device=dev))
        self.w2 = nn.Parameter(torch.empty(n_local, d_ff, d_model, device=dev))
        self.b2 = nn.Parameter(torch.empty(n_local, 1, d_model, device=dev))

        std_in = math.sqrt(0.1 / d_model)
        std_out = math.sqrt(0.1 / d_ff)
        # Initialise under the tensor-parallel seed; the call order below
        # determines which random values each tensor receives.
        with seed(ParallelMode.TENSOR):
            for tensor, std in ((self.w1, std_in), (self.b1, std_in), (self.w2, std_out), (self.b2, std_out)):
                nn.init.trunc_normal_(tensor, std=std)

        self.act = activation if activation is not None else nn.GELU()
        self.drop = nn.Dropout(p=drop_rate)

        # Tag every parameter of this module with the parallel info from MOE_CONTEXT.
        for p in self.parameters():
            setattr(p, 'moe_info', self.dist_info)

    def forward(self, inputs):
        """Apply the stacked FFN experts.

        ``inputs`` is laid out as [g, el, c, h]. It is moved to expert-major
        order and flattened to [el, g*c, h], passed through
        linear -> act -> dropout -> linear -> dropout, and restored to
        [g, el, c, h].
        """
        n_local, hidden = inputs.size(1), inputs.size(-1)

        expert_major = inputs.transpose(0, 1)
        saved_shape = expert_major.shape
        flat = expert_major.reshape(n_local, -1, hidden)

        inner = self.act(torch.baddbmm(self.b1, flat, self.w1))
        with seed(ParallelMode.TENSOR):
            inner = self.drop(inner)

        projected = torch.baddbmm(self.b2, inner, self.w2)
        with seed(ParallelMode.TENSOR):
            projected = self.drop(projected)  # [el, gc, h]

        return projected.reshape(saved_shape).transpose(0, 1).contiguous()
|
2022-02-27 14:28:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TPExperts(MoeExperts):
    """Experts split evenly with tensor parallelism.

    Useful when the number of experts is not divisible by the maximum expert
    parallel size, or the maximum expert parallel size is not divisible by
    the number of experts. Each rank keeps all ``num_experts`` experts but
    only a ``d_ff // max_ep_size`` slice of their inner dimension, and tokens
    are exchanged via ``"all_gather"``.

    Args:
        num_experts (int): Number of experts stored (in sliced form) on this rank.
        d_model (int): Size of the last dimension of the inputs.
        d_ff (int): Full inner size of each expert; must be divisible by ``MOE_CONTEXT.max_ep_size``.
        activation: Activation applied after the first projection; ``nn.GELU()`` when None.
        drop_rate (float): Dropout probability used after the activation and after the second projection.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

        ep_size = MOE_CONTEXT.max_ep_size
        assert d_ff % ep_size == 0, \
            "d_ff should be divied by maximum expert parallel size"
        p_ff = d_ff // ep_size

        dev = get_current_device()
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=dev))
        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=dev))
        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=dev))
        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=dev))

        std_in = math.sqrt(0.1 / d_model)
        std_out = math.sqrt(0.1 / d_ff)
        # The sliced tensors are initialised under the tensor-parallel seed.
        with seed(ParallelMode.TENSOR):
            for tensor, std in ((self.w1, std_in), (self.b1, std_in), (self.w2, std_out)):
                nn.init.trunc_normal_(tensor, std=std)
        # NOTE(review): b2 is initialised OUTSIDE the tensor seed, unlike the
        # other tensors, presumably so it is identical on every rank. Keep this
        # call after the block above; confirm the intent.
        nn.init.trunc_normal_(self.b2, std=std_out)

        self.act = activation if activation is not None else nn.GELU()
        self.drop = nn.Dropout(p=drop_rate)

        # Only the sliced tensors carry moe_info.
        # NOTE(review): b2 is deliberately left untagged here; confirm how
        # untagged params are handled.
        for shard in (self.w1, self.w2, self.b1):
            setattr(shard, 'moe_info', self.dist_info)

    def forward(self, inputs):
        """Apply the tensor-parallel expert slices.

        ``inputs`` is laid out as [g, e, c, h] and is returned in the same
        layout. Internally it is flattened to [e, g*c, h] for the batched
        matmuls.
        """
        n_exp, hidden = inputs.size(1), inputs.size(-1)

        expert_major = inputs.transpose(0, 1)
        saved_shape = expert_major.shape
        flat = expert_major.reshape(n_exp, -1, hidden)

        inner = self.act(torch.baddbmm(self.b1, flat, self.w1))
        with seed(ParallelMode.TENSOR):
            inner = self.drop(inner)

        # NOTE(review): unlike FFNExperts, this dropout runs outside the
        # tensor seed, presumably so every rank draws the same mask on its
        # partial output; confirm before changing.
        projected = self.drop(torch.baddbmm(self.b2, inner, self.w2))  # [e, gc, h]

        return projected.reshape(saved_shape).transpose(0, 1).contiguous()  # [g, e, c, h]
|