from typing import Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo


class MoEManager(metaclass=SingletonMeta):
    """MoE manager. This class manages the parallel groups used in the MoE
    context and the router losses (aux loss and z-loss) collected during training.
    """

    def __init__(self):
        self.parallel = None
        self.mode = None
        self.use_ep_inside = None
        self.world_size = None
        self._parallel_info_dict = dict()

        # router
        self.router_aux_loss = []
        self.router_z_loss = []

        # fixed mode
        self.pp_size = None
        self.dp_size = None
        self.ep_size = None

        # dynamic mode
        # Users may want to set the maximum expert parallel size smaller than the world size,
        # since very low bandwidth across nodes may constrain the performance of MoE.
        # A maximum expert parallel size naturally implies a minimum data parallel size.
        self.max_ep_size = None

        self.has_setup = False

    @property
    def parallel_info_dict(self):
        return self._parallel_info_dict

    @property
    def is_initialized(self):
        return self.has_setup

    def setup(
        self,
        parallel: str = None,
        mode: str = "dynamic",
        max_ep_size: int = 8,
        fixed_dp_size: int = 0,
        fixed_ep_size: int = 0,
        fixed_pp_size: int = 0,
        use_ep_inside: bool = True,
    ) -> None:
        """
        Setup the MoE distributed context.

        Args:
            parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
                In fixed mode, the ep size and dp size are fixed.
                In dynamic mode, the ep size and dp size are adjusted according to the number of experts.
            max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
            fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
            fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
            fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
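
        Example:
            A minimal sketch, assuming CUDA is available and the default process
            group has already been initialized (e.g. via colossalai.launch or
            torch.distributed.init_process_group):

            >>> MOE_MANAGER.setup(parallel="EP", mode="dynamic", max_ep_size=4)
            >>> MOE_MANAGER.is_initialized
            True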
        """
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        assert torch.cuda.is_available(), "MoE requires CUDA to be available"

        self.parallel = parallel
        self.use_ep_inside = use_ep_inside
        self.world_size = dist.get_world_size()

        # init by mode
        self.mode = mode
        assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
        if self.mode == "dynamic":
            self.max_ep_size = min(max_ep_size, self.world_size)
        else:
            assert (
                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
            ), "dp_size, ep_size and pp_size should be greater than 0"
            assert (
                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
            ), "dp_size, ep_size and pp_size should be int"
            self.ep_size = fixed_ep_size
            self.dp_size = fixed_dp_size
            self.pp_size = fixed_pp_size

        self.has_setup = True

    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the data parallel group and expert parallel group.

        Parameters
        ----------
        num_experts : int
            The number of experts.
        use_tp : bool, optional
            Whether tensor parallelism is applied to the experts. If True, every rank
            keeps all experts locally instead of distributing whole experts across the
            expert parallel group. Defaults to False.

        Returns
        -------
        int, MoeParallelInfo
            The number of local experts and the MoeParallelInfo of the current ep_size.
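
        Examples
        --------
        A minimal sketch, assuming MOE_MANAGER.setup(parallel="EP", mode="dynamic")
        has been called on a 4-GPU run, so max_ep_size is capped at the world size of 4.

        >>> num_local_experts, moe_info = MOE_MANAGER.get_info(num_experts=8)
        >>> num_local_experts  # 8 experts spread over ep_size = 4 ranks
        2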
        """
        if self.mode == "dynamic":
            gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
            lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
            assert gt_flag or lt_flag, (
                "Automatic expert placement does not support an expert number that is"
                " not a multiple of the ep size, or vice versa."
            )
            dp_size = 1 if gt_flag else self.world_size // num_experts
            ep_size = min(self.world_size // dp_size, self.max_ep_size)
            dp_size = self.world_size // ep_size
            pp_size = 1
        else:
            dp_size = self.dp_size
            ep_size = self.ep_size
            pp_size = self.pp_size

        # Calculate the number of experts for each GPU
        if use_tp:
            num_local_experts = num_experts
        else:
            if self.mode == "dynamic":
                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
            else:
                num_local_experts = num_experts // ep_size

        if ep_size not in self.parallel_info_dict:
            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
            if dist.get_rank() == 0:
                if self.use_ep_inside:
                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
                else:
                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

        return num_local_experts, self.parallel_info_dict[ep_size]

    def reset_loss(self):
        self.router_aux_loss, self.router_z_loss = [], []

    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
        self.router_aux_loss.append(aux_loss)
        self.router_z_loss.append(z_loss)

    def get_loss(self):
        return self.router_aux_loss, self.router_z_loss

    def get_parallel(self):
        return self.parallel


MOE_MANAGER = MoEManager()
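
# Illustrative usage sketch (comments only), assuming a CUDA-enabled run whose default
# process group has already been initialized, e.g. via colossalai.launch or
# torch.distributed.init_process_group:
#
#   MOE_MANAGER.setup(parallel="EP", mode="dynamic", max_ep_size=8)
#   num_local_experts, moe_info = MOE_MANAGER.get_info(num_experts=16)
#   MOE_MANAGER.add_loss(aux_loss=0.01, z_loss=0.001)  # accumulate per-layer router losses
#   aux_losses, z_losses = MOE_MANAGER.get_loss()
#   MOE_MANAGER.reset_loss()  # clear the accumulated losses after each step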