ColossalAI/colossalai/moe/layers.py

401 lines
16 KiB
Python

import dataclasses
import math
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
class SparseMLP(nn.Module):
    """A Mixture-of-Experts MLP layer for users to build MoE models.

    Each token is scored by a learned linear gate, assigned to experts by a router,
    processed by ``MLPExperts`` (locally, or with expert-parallel / tensor-parallel
    communication depending on ``MOE_MANAGER.get_parallel()``), and the expert outputs
    are combined back into the shape of the input.

    Args:
        num_experts (int): Total number of experts (number of rows of the gate weight).
        hidden_size (int): Hidden dimension of the incoming tokens.
        intermediate_size (int): Intermediate dimension of each expert MLP.
        router_top_k (int, optional): Number of experts each token is dispatched to. It selects
            the router class via ``get_router_cls`` and is also used when counting expert
            load. Defaults to 1.
        router_loss (bool, optional): Forwarded to the router call as ``use_loss``.
            Defaults to True.
        router_norm (bool, optional): Forwarded to the router call as ``use_norm``.
            Defaults to False.
        router_capacity_factor_train (float, optional): Capacity factor in routing during
            training. Defaults to 1.25.
        router_capacity_factor_eval (float, optional): Capacity factor in routing during
            evaluation. Defaults to 2.0.
        router_min_capacity (int, optional): Minimum capacity of each expert. Defaults to 4.
        router_noisy_policy (str, optional): Noise policy passed to ``get_noise_generator``,
            e.g. 'Jitter' (see `Switch Transformer paper`_) or 'Gaussian'
            (see `ViT-MoE paper`_). Defaults to None.
        router_drop_tks (bool, optional): Whether the router drops tokens in evaluation
            (forwarded as ``drop_tks``). Defaults to True.
        mlp_activation (str, optional): Activation name forwarded to ``MLPExperts``.
            Defaults to None.
        mlp_gated (bool, optional): Whether experts are gated MLPs; forwarded to
            ``MLPExperts``. Defaults to False.
        enable_load_balance (bool, optional): If True, build a ``LoadBalancer`` and record
            per-expert token counts on every forward pass. Defaults to False.
        load_balance_tolerance (float, optional): Forwarded to ``LoadBalancer``. Defaults to 0.1.
        load_balance_beam_width (int, optional): Forwarded to ``LoadBalancer``. Defaults to 8.
        load_balance_group_swap_factor (float, optional): Forwarded to ``LoadBalancer``.
            Defaults to 0.4.
        enable_kernel (bool, optional): Use the fused ``MoeDispatch``/``MoeCombine`` ops. It is
            also forwarded to the router and the experts. Defaults to False.
        enable_comm_overlap (bool, optional): In "EP"/"TP" mode, split the work into chunks so
            that communication overlaps expert computation. Defaults to False.
        enable_hierarchical_comm (bool, optional): In expert-parallel mode, build a
            hierarchical group and use ``HierarchicalAllToAll`` (only when overlap is
            disabled). Defaults to False.
        return_gate_logits (bool, optional): If True, ``forward`` returns
            ``(output, gate_logits)`` instead of just ``output``. Defaults to False.

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        router_loss: bool = True,
        router_norm: bool = False,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        mlp_activation: Optional[str] = None,
        mlp_gated: bool = False,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = False,
        return_gate_logits: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.return_gate_logits = return_gate_logits
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        # one of None / "EP" / "TP" is expected by forward(); anything else raises there
        self.expert_parallel = MOE_MANAGER.get_parallel()
        self.router_loss = router_loss
        self.router_norm = router_norm

        # moe router
        noisy_func = get_noise_generator(router_noisy_policy, num_experts)
        router_cls = get_router_cls(router_top_k)
        self.topk = router_top_k
        self.router: MoeRouter = router_cls(
            capacity_factor_train=router_capacity_factor_train,
            capacity_factor_eval=router_capacity_factor_eval,
            min_capacity=router_min_capacity,
            noisy_func=noisy_func,
            drop_tks=router_drop_tks,
        )

        # gate: one row of weights per expert, applied to hidden_size-dim tokens
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))

        # moe experts
        self.experts = MLPExperts(
            num_experts=self.num_experts,
            expert_parallel=self.expert_parallel,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            activation=mlp_activation,
            gated=mlp_gated,
            use_kernel=self.enable_kernel,
        )

        # get parallel settings
        if self.expert_parallel is not None:
            self.ep_group = get_ep_group(self.experts)
            self.ep_size = get_ep_size(self.experts)
            self.ep_hierarchical_group = None
            if enable_hierarchical_comm:
                # the first returned value is the intra-node source rank; the rest is the
                # hierarchical group description consumed by HierarchicalAllToAll
                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
                    get_ep_group_ranks(self.experts)
                )
            self.dp_group = get_dp_group(self.experts)
        else:
            self.ep_group = None
            self.dp_group = None
            # keep the attribute defined in every configuration
            self.ep_hierarchical_group = None
        self.num_local_experts = self.experts.num_local_experts

        # load balance
        self.enable_load_balance = enable_load_balance
        if self.enable_load_balance:
            self.load_balancer = LoadBalancer(
                experts=self.experts,
                gate=self.gate_weight,
                local_expert_num=self.num_local_experts,
                expert_num=self.num_experts,
                ep_group=self.ep_group,
                dp_group=self.dp_group,
                tolerance=load_balance_tolerance,
                beam_width=load_balance_beam_width,
                group_swap_factor=load_balance_group_swap_factor,
            )

        # init param
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Initialise the gate weight from N(0, 0.1 / hidden_size)."""
        torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and combine their outputs.

        Args:
            inputs (torch.Tensor): Input whose last dimension is ``hidden_size``, e.g.
                (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Output with the same shape as ``inputs``. When
            ``return_gate_logits`` is True, a tuple ``(output, gate_logits)`` is returned
            instead, where ``gate_logits`` has shape (num_tokens, num_experts) and is taken
            before the fp32 cast.

        Raises:
            NotImplementedError: If the parallel mode is not None, "EP" or "TP".
        """
        # flatten every leading dimension into a single token dimension
        tokens = inputs.reshape(-1, self.hidden_size)

        # the gate linear runs in the input dtype; its logits are upcast to fp32 for routing
        gate_logits = F.linear(tokens, self.gate_weight)
        gate_output = gate_logits.to(torch.float)

        # update expert load
        if self.enable_load_balance:
            with torch.no_grad():
                # TODO: optimize computation
                expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
                # TODO: bincount introduces synchronize, fix it
                # minlength guarantees one count per expert even when the highest-indexed
                # experts received no tokens in this batch
                expert_load = torch.bincount(expert_load.view(-1), minlength=self.num_experts)
                self.load_balancer.update_load(expert_load)

        # the result from the router
        used_capacity, *route_result_list = self.router(
            inputs=gate_output,
            use_kernel=self.enable_kernel,
            ep_group=self.ep_group,
            use_loss=self.router_loss,
            use_norm=self.router_norm,
        )

        # dispatch_data: (num_experts, capacity, hidden_size)
        if self.enable_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
        else:
            # route_result_list[1] is used as a dispatch mask that is permuted so the
            # matmul gathers tokens per (expert, capacity slot)
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # expert_output: (num_groups, num_experts, capacity, hidden_size)
        if self.expert_parallel == "EP":
            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel == "TP":
            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel is None:
            expert_output = self._local_process(dispatch_data)
        else:
            raise NotImplementedError(
                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
            )

        if self.enable_kernel:
            expert_output = expert_output.reshape(-1, self.hidden_size)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            # route_result_list[0] holds the combine weights, flattened to (num_tokens, -1)
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)

        if self.return_gate_logits:
            return ans, gate_logits
        else:
            return ans

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        """Run all experts locally after adding a leading group dimension of size 1."""
        expert_in = expert_in.unsqueeze(0)
        expert_out = self.experts(expert_in)
        return expert_out

    def _ep_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """Expert-parallel processing: all-to-all in, local experts, all-to-all out.

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): Router output; not used by this method.
            overlap (bool, optional): Split capacity into chunks so that all-to-all
                communication overlaps expert computation. Ignored when the EP group has a
                single rank.

        Returns:
            torch.Tensor: Expert outputs. On the overlap path the shape is
            (ep_size, num_local_experts, capacity, hidden_size).
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            if self.ep_hierarchical_group is not None:
                expert_input = HierarchicalAllToAll.apply(
                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = HierarchicalAllToAll.apply(
                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                return expert_output
            else:
                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
                return expert_output
        else:

            @dataclasses.dataclass
            class Capsule:
                # a tensor plus the async communication handle to wait on (None if computed)
                data: torch.Tensor
                handle: Any = None

            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
            dispatch_data = dispatch_data.reshape(*input_shape)
            # chunk along the capacity dimension
            chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
            output = torch.empty_like(dispatch_data)

            offset = 0
            # _expert_in: all-to-all of the input in flight; expert_in: input ready to compute
            # _expert_out: computed output awaiting all-to-all; expert_out: all-to-all in flight
            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                    offset += chunk_size
                    expert_out = None

                # all2all last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
                    )
                    _expert_out = None

                # all2all next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
                    expert_in = None

                # the input launched this step becomes computable next step
                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output

    def _tp_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """Tensor-parallel processing: all-gather in, experts, reduce-scatter out.

        without overlap:
                | C |
            | A |   | R |

        with overlap:
                 | C1 || C2 || C3 || C4 |
            | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |

        where C is computation, A is all gather, R is reduce scatter.

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): Router output; not used by this method.
            overlap (bool, optional): Split along the expert dimension into chunks so that
                communication overlaps computation. Ignored when the group has a single rank.

        Returns:
            torch.Tensor: (num_experts, capacity, hidden_size) on the overlap path.
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
            expert_out = self.experts(expert_in)
            expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
            return expert_out
        else:

            @dataclasses.dataclass
            class Capsule:
                # a tensor, its async communication handle, and the expert slice it covers
                data: torch.Tensor
                handle: Any
                indices: Tuple

            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert (
                dispatch_data.shape[0] % NUM_CHUNK == 0
            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
            # chunk along the expert dimension
            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
            output = torch.empty_like(dispatch_data)

            def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                # index tuple selecting experts [idx * chunk_size, (idx + 1) * chunk_size)
                return (slice(idx * chunk_size, (idx + 1) * chunk_size),)

            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[expert_out.indices] = expert_out.data
                    expert_out = None

                # reduce scatter last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
                        indices=_expert_out.indices,
                    )
                    _expert_out = None

                # all gather next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(
                        *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
                        indices=get_chunk_slice(i, chunk_size),
                    )

                # compute; the experts receive the slice so they use the matching parameters
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(
                        self.experts(expert_in.data, expert_in.indices),
                        handle=None,
                        indices=expert_in.indices,
                    )
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output
def apply_load_balance(model: nn.Module, optim: Any) -> None:
    """Apply load balance to every load-balancing ``SparseMLP`` in ``model``.

    Every module reachable from ``model``, including ``model`` itself, is checked. Each one
    that is a ``SparseMLP`` with ``enable_load_balance`` set has its
    ``load_balancer.balance_load(optim)`` called exactly once. The CUDA cache is emptied
    before and after balancing.

    Args:
        model (nn.Module): Model to scan for ``SparseMLP`` layers.
        optim (Any): Passed through unchanged to ``load_balancer.balance_load``.
    """
    torch.cuda.empty_cache()
    # ``modules()`` walks the tree in pre-order, includes the root (so a bare SparseMLP is
    # balanced too) and yields shared submodules only once (so none is balanced twice).
    for module in model.modules():
        if isinstance(module, SparseMLP) and module.enable_load_balance:
            module.load_balancer.balance_load(optim)
    torch.cuda.empty_cache()