# ColossalAI — colossalai/nn/layer/moe/experts.py
# Expert modules for Mixture-of-Experts (MoE) layers.
import math
import torch
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
from colossalai.context.moe_context import MOE_CONTEXT
from typing import Type
class MoeExperts(nn.Module):
    """Base class for experts in MoE layers.

    Records which collective communication the experts use to exchange
    tokens ("all_to_all" or "all_gather"), how many experts live on the
    local device, and the distributed parallel information (expert
    parallel size, data parallel size and their communication groups)
    obtained from the global MoE context.
    """

    def __init__(self, comm_name: str, num_experts: int):
        super().__init__()
        assert comm_name in {"all_to_all", "all_gather"}, \
            "This kind of communication has not been implemented yet.\n Please use Experts build function."
        self.comm_name = comm_name
        # Ask the MoE context how experts are deployed: the number of
        # experts hosted locally and the parallel/communication groups.
        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
class Experts(MoeExperts):
    """A wrapper class to create experts. It creates E experts across the
    MoE model parallel group, where E is the number of experts. Every
    expert is an instance of the class given by ``expert_cls``.

    Args:
        expert_cls (:class:`torch.nn.Module`): The class of all experts
        num_experts (int): The number of experts
        expert_args: Args used to initialize experts, the args could be found in corresponding expert class
    """

    def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
        super().__init__("all_to_all", num_experts)

        # Seed per tensor-parallel rank so every expert is initialized
        # differently from the others.
        with seed(ParallelMode.TENSOR):
            local_experts = [expert_cls(**expert_args) for _ in range(self.num_local_experts)]
        self.experts = nn.ModuleList(local_experts)

        # Tag every expert parameter with its MoE parallel information.
        for expert in self.experts:
            for param in expert.parameters():
                param.moe_info = self.dist_info

    def forward(self, inputs: torch.Tensor):
        """Dispatch one chunk of ``inputs`` (split along dim 1) to each
        local expert and concatenate the results back along dim 1."""
        chunks = torch.chunk(inputs, self.num_local_experts, dim=1)
        expert_outputs = [expert(chunk) for expert, chunk in zip(self.experts, chunks)]
        return torch.cat(expert_outputs, dim=1).contiguous()
class FFNExperts(MoeExperts):
    """Feed-forward experts that run all local experts at once with
    ``torch.baddbmm`` (batched matmul) for speed."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_to_all", num_experts)

        device = get_current_device()
        n_local = self.num_local_experts
        # Per-expert FFN weights: in-projection (d_model -> d_ff) and
        # out-projection (d_ff -> d_model), stacked on the expert dim.
        self.w1 = nn.Parameter(torch.empty(n_local, d_model, d_ff, device=device))
        self.b1 = nn.Parameter(torch.empty(n_local, 1, d_ff, device=device))
        self.w2 = nn.Parameter(torch.empty(n_local, d_ff, d_model, device=device))
        self.b2 = nn.Parameter(torch.empty(n_local, 1, d_model, device=device))

        std_in = math.sqrt(0.1 / d_model)
        std_out = math.sqrt(0.1 / d_ff)
        # Tensor-parallel seed so each rank draws different weights.
        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=std_in)
            nn.init.trunc_normal_(self.b1, std=std_in)
            nn.init.trunc_normal_(self.w2, std=std_out)
            nn.init.trunc_normal_(self.b2, std=std_out)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        # Tag every parameter with its MoE parallel information.
        for param in self.parameters():
            param.moe_info = self.dist_info

    def forward(self, inputs):  # inputs [g, el, c, h]
        n_local = inputs.size(1)
        hidden = inputs.size(-1)

        # Bring the local-expert dim first and flatten the rest so each
        # expert's tokens form one batch entry for the batched matmuls.
        x = inputs.transpose(0, 1)
        folded_shape = x.shape
        x = x.reshape(n_local, -1, hidden)

        x = torch.baddbmm(self.b1, x, self.w1)
        x = self.act(x)
        with seed(ParallelMode.TENSOR):
            x = self.drop(x)
        x = torch.baddbmm(self.b2, x, self.w2)
        with seed(ParallelMode.TENSOR):
            x = self.drop(x)  # [el, gc, h]

        x = x.reshape(folded_shape)
        return x.transpose(0, 1).contiguous()
class TPExperts(MoeExperts):
    """Splits each expert evenly with tensor parallelism, which allows
    deploying experts when the number of experts is not divisible by the
    maximum expert parallel size, or vice versa."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
            "d_ff should be divied by maximum expert parallel size"

        # Each rank holds a 1/max_ep_size slice of every expert's d_ff dim.
        p_ff = d_ff // MOE_CONTEXT.max_ep_size
        device = get_current_device()
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=device))
        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=device))
        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=device))
        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=device))

        std_in = math.sqrt(0.1 / d_model)
        std_out = math.sqrt(0.1 / d_ff)
        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.w1, std=std_in)
            nn.init.trunc_normal_(self.b1, std=std_in)
            nn.init.trunc_normal_(self.w2, std=std_out)
            nn.init.trunc_normal_(self.b2, std=std_out)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        # NOTE(review): only the sharded parameters are tagged; b2 carries
        # no moe_info — presumably intentional because its output dim is
        # replicated across the tensor parallel group, confirm against the
        # gradient-sync handler.
        self.w1.moe_info = self.dist_info
        self.w2.moe_info = self.dist_info
        self.b1.moe_info = self.dist_info

    def forward(self, inputs):  # inputs [g, e, c, h]
        n_experts = inputs.size(1)
        hidden = inputs.size(-1)

        # Expert dim first, tokens flattened per expert for batched matmul.
        x = inputs.transpose(0, 1)
        folded_shape = x.shape
        x = x.reshape(n_experts, -1, hidden)

        x = torch.baddbmm(self.b1, x, self.w1)
        x = self.act(x)
        with seed(ParallelMode.TENSOR):
            x = self.drop(x)
        x = torch.baddbmm(self.b2, x, self.w2)
        # NOTE(review): unlike FFNExperts, this dropout runs under the
        # default seed rather than the TENSOR-mode seed — presumably so the
        # mask agrees across the tensor parallel group; confirm.
        x = self.drop(x)  # [e, gc, h]

        x = x.reshape(folded_shape)
        return x.transpose(0, 1).contiguous()  # outputs [g, e, c, h]