mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
392 lines
16 KiB
392 lines
16 KiB
import dataclasses |
|
import math |
|
from typing import Any, Optional, Tuple |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter |
|
from colossalai.moe.experts import MLPExperts |
|
from colossalai.moe.load_balance import LoadBalancer |
|
from colossalai.moe.manager import MOE_MANAGER |
|
from colossalai.moe.routers import MoeRouter, get_router_cls |
|
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator |
|
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size |
|
|
|
|
|
class SparseMLP(nn.Module):
    """A sparse mixture-of-experts (MoE) MLP layer for users to build MoE models.

    Each token is scored by a linear gate (computed in fp32), routed by a
    ``MoeRouter``, processed by ``MLPExperts`` and combined back into the
    original token order. The experts run in one of three modes, as reported by
    ``MOE_MANAGER.get_parallel()``:
    - locally (``None``);
    - with expert-parallel all-to-all communication (``"EP"``);
    - with tensor-parallel all-gather / reduce-scatter communication (``"TP"``).

    Args:
        num_experts (int): The total number of experts.
        hidden_size (int): Hidden dimension of the input tokens; the last
            dimension of ``inputs`` given to ``forward`` must equal this value.
        intermediate_size (int): Intermediate dimension of each expert MLP
            (forwarded to ``MLPExperts``).
        router_top_k (int, optional): The number of experts each token is
            dispatched to; also selects the router class via ``get_router_cls``.
            Defaults to 1.
        router_capacity_factor_train (float, optional): Capacity factor used in
            routing during training. Defaults to 1.25.
        router_capacity_factor_eval (float, optional): Capacity factor used in
            routing during evaluation. Defaults to 2.0.
        router_min_capacity (int, optional): The minimum capacity of each
            expert. Defaults to 4.
        router_noisy_policy (str, optional): The policy of the noise function
            used in routing. Now we have 'Jitter' and 'Gaussian'.
            'Jitter' can be found in `Switch Transformer paper`_.
            'Gaussian' can be found in `ViT-MoE paper`_.
            Defaults to None.
        router_drop_tks (bool, optional): Whether the router drops tokens in
            evaluation. Defaults to True.
        mlp_activation (str, optional): Activation used inside each expert
            (forwarded to ``MLPExperts``). Defaults to None.
        mlp_gated (bool, optional): Whether the experts are gated MLPs
            (forwarded to ``MLPExperts``). Defaults to False.
        enable_load_balance (bool, optional): Whether to build a
            ``LoadBalancer`` and record per-expert load in every ``forward``
            call. Defaults to False.
        load_balance_tolerance (float, optional): Forwarded to
            ``LoadBalancer``. Defaults to 0.1.
        load_balance_beam_width (int, optional): Forwarded to
            ``LoadBalancer``. Defaults to 8.
        load_balance_group_swap_factor (float, optional): Forwarded to
            ``LoadBalancer``. Defaults to 0.4.
        enable_kernel (bool, optional): Whether to use the fused
            ``MoeDispatch`` / ``MoeCombine`` ops. The flag is also forwarded to
            the router and the experts. Defaults to False.
        enable_comm_overlap (bool, optional): Whether to split the expert
            computation into chunks and overlap it with communication in the
            "EP" and "TP" modes. Defaults to False.
        enable_hierarchical_comm (bool, optional): Whether to use the
            hierarchical all-to-all in the non-overlapped "EP" path.
            Defaults to False.

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        mlp_activation: Optional[str] = None,
        mlp_gated: bool = False,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        # ``forward`` handles None, "EP" and "TP"; anything else raises there.
        self.expert_parallel = MOE_MANAGER.get_parallel()

        # moe router
        noisy_func = get_noise_generator(router_noisy_policy, num_experts)
        router_cls = get_router_cls(router_top_k)
        self.topk = router_top_k
        self.router: MoeRouter = router_cls(
            capacity_factor_train=router_capacity_factor_train,
            capacity_factor_eval=router_capacity_factor_eval,
            min_capacity=router_min_capacity,
            noisy_func=noisy_func,
            drop_tks=router_drop_tks,
        )

        # gate: (num_experts, hidden_size), applied with F.linear in ``forward``
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))

        # moe experts
        self.experts = MLPExperts(
            num_experts=self.num_experts,
            expert_parallel=self.expert_parallel,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            activation=mlp_activation,
            gated=mlp_gated,
            use_kernel=self.enable_kernel,
        )

        # get parallel settings
        # Always defined so the attribute exists in every mode; it only becomes
        # non-None when hierarchical all-to-all is enabled.
        self.ep_hierarchical_group = None
        if self.expert_parallel is not None:
            self.ep_group = get_ep_group(self.experts)
            self.ep_size = get_ep_size(self.experts)
            if enable_hierarchical_comm:
                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
                    get_ep_group_ranks(self.experts)
                )
            self.dp_group = get_dp_group(self.experts)
        else:
            self.ep_group = None
            self.dp_group = None
        self.num_local_experts = self.experts.num_local_experts

        # load balance
        self.enable_load_balance = enable_load_balance
        if self.enable_load_balance:
            self.load_balancer = LoadBalancer(
                experts=self.experts,
                gate=self.gate_weight,
                local_expert_num=self.num_local_experts,
                expert_num=self.num_experts,
                ep_group=self.ep_group,
                dp_group=self.dp_group,
                tolerance=load_balance_tolerance,
                beam_width=load_balance_beam_width,
                group_swap_factor=load_balance_group_swap_factor,
            )

        # init param
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialize the gate weight from N(0, 0.1 / hidden_size)."""
        torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

    @staticmethod
    def _count_expert_load(gate_output: torch.Tensor, top_k: int, num_experts: int) -> torch.Tensor:
        """Count how many top-k routing slots each expert receives.

        Args:
            gate_output (torch.Tensor): Gate logits whose last dimension indexes experts.
            top_k (int): The number of experts selected per token.
            num_experts (int): The total number of experts.

        Returns:
            torch.Tensor: Integer counts of length ``num_experts``. Experts that
            received no tokens get an explicit 0, including the highest-indexed ones.
        """
        # TODO: optimize computation
        top_idx = torch.topk(gate_output, k=top_k, dim=-1)[1]
        # TODO: bincount introduces synchronize, fix it
        # ``minlength`` pins the output length to num_experts; without it the
        # length would be (largest selected index + 1) and vary per batch.
        return torch.bincount(top_idx.view(-1), minlength=num_experts)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts, run the experts and combine their outputs.

        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)

        Returns:
            torch.Tensor: The output tensor, reshaped to ``inputs.shape``.

        Raises:
            NotImplementedError: If the parallel mode is not None, "EP" or "TP".
        """
        # flatten all leading dims: (num_tokens, hidden_size)
        tokens = inputs.reshape(-1, self.hidden_size)

        # the data type of the inputs in the gating should be fp32
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        gate_output = F.linear(fp32_input, fp32_weight)  # (num_tokens, num_experts)

        # update expert load
        if self.enable_load_balance:
            with torch.no_grad():
                expert_load = self._count_expert_load(gate_output, self.topk, self.num_experts)
                self.load_balancer.update_load(expert_load)

        # the result from the router
        used_capacity, *route_result_list = self.router(
            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group
        )

        # dispatch_data: (num_experts, capacity, hidden_size)
        if self.enable_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
        else:
            # route_result_list[1] is presumably the (num_tokens, num_experts, capacity)
            # dispatch mask; permuting puts tokens last so the matmul gathers them.
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # expert_output: its layout depends on the parallel mode (e.g. the local
        # path adds a leading dim of size 1); it is flattened before combining.
        if self.expert_parallel == "EP":
            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel == "TP":
            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel is None:
            expert_output = self._local_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
                                      "Please use Experts build function.")

        if self.enable_kernel:
            expert_output = expert_output.reshape(-1, self.hidden_size)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            # combine_weights: (num_tokens, num_experts * capacity) after the view
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        return ans

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        """Run all experts locally; adds a leading dim of size 1 before calling them."""
        expert_in = expert_in.unsqueeze(0)
        expert_out = self.experts(expert_in)
        return expert_out

    def _ep_process(
        self,
        dispatch_data: torch.Tensor,
        used_capacity: torch.Tensor,
        overlap: bool = False
    ) -> torch.Tensor:
        """
        Expert Parallel

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): Router output; not used by this method.
            overlap (bool, optional): Whether to pipeline the all-to-all with the
                expert computation in NUM_CHUNK chunks along the capacity dim.
                Ignored when the EP group has a single rank.

        Returns:
            torch.Tensor: Expert outputs. The overlap path allocates a tensor of
            shape (ep_size, num_local_experts, capacity, hidden_size); the other
            paths return the all-to-all of the experts' output, presumably the
            same layout (TODO confirm against AllToAll / HierarchicalAllToAll).
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            if self.ep_hierarchical_group is not None:
                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                return expert_output
            else:
                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
                return expert_output
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any = None

            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
            dispatch_data = dispatch_data.reshape(*input_shape)
            chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
            output = torch.empty_like(dispatch_data)

            offset = 0
            # _expert_in: input all2all issued this step; expert_in: input ready to compute
            # _expert_out: computed, not yet sent back; expert_out: output all2all in flight
            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            # NUM_STAGES - 1 extra iterations drain the pipeline after the last chunk is issued.
            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
                    offset += chunk_size
                    expert_out = None

                # all2all last output
                if _expert_out is not None:
                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True))
                    _expert_out = None

                # all2all next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output

    def _tp_process(
        self,
        dispatch_data: torch.Tensor,
        used_capacity: torch.Tensor,
        overlap: bool = False
    ) -> torch.Tensor:
        """
        without overlap:
        |    C    |
        | A |       | R |

        with overlap:
        |    C1   ||    C2   ||    C3   ||    C4   |
        | A1 || A2 |     | R1 | A3 || R2 | A4 || R3 |     | R4 |

        where C is computation, A is all gather, R is reduce scatter.

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): Router output; not used by this method.
            overlap (bool, optional): Whether to pipeline communication with the
                expert computation in NUM_CHUNK chunks along the expert dim.
                Ignored when the EP group has a single rank.

        Returns:
            torch.Tensor: (num_experts, capacity, hidden_size)
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
            expert_out = self.experts(expert_in)
            expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
            return expert_out
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any
                indices: Tuple

            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
            output = torch.empty_like(dispatch_data)

            def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                # index tuple selecting chunk ``idx`` along the expert dim
                return (slice(idx * chunk_size, (idx + 1) * chunk_size),)

            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            # NUM_STAGES - 1 extra iterations drain the pipeline after the last chunk is issued.
            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[expert_out.indices] = expert_out.data
                    expert_out = None

                # reduce scatter last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
                        indices=_expert_out.indices,
                    )
                    _expert_out = None

                # all gather next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(
                        *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
                        indices=get_chunk_slice(i, chunk_size),
                    )

                # compute (the slice tells the experts which parameter chunk to use)
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(
                        self.experts(expert_in.data, expert_in.indices),
                        handle=None,
                        indices=expert_in.indices,
                    )
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output
|
|
|
|
|
def apply_load_balance(model: nn.Module, optim: Any) -> None:
    """
    Apply load balance to every ``SparseMLP`` in the model that has it enabled.

    ``model`` itself and each of its submodules are visited exactly once, even
    when a submodule is shared between several parents.
    ``load_balancer.balance_load(optim)`` is called on every ``SparseMLP`` whose
    ``enable_load_balance`` flag is set.

    Args:
        model (nn.Module): The model to scan for MoE layers.
        optim (Any): The optimizer handed to each layer's ``balance_load``.
    """
    # release cached CUDA allocator memory around the re-balancing
    torch.cuda.empty_cache()
    # ``modules()`` walks ``model`` and all descendants in pre-order with
    # de-duplication, so a shared SparseMLP is not balanced twice and a
    # SparseMLP passed directly as ``model`` is not skipped.
    for module in model.modules():
        if isinstance(module, SparseMLP) and module.enable_load_balance:
            module.load_balancer.balance_load(optim)
    torch.cuda.empty_cache()
|
|
|