ColossalAI/colossalai/moe/manager.py

164 lines
6.0 KiB
Python

from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
class MoEManager(metaclass=SingletonMeta):
    """MoE manager.

    Manages the parallel groups used in an MoE context and accumulates the
    router losses (auxiliary loss and z-loss) recorded during training.
    """

    def __init__(self):
        """Create an un-initialized manager; call :meth:`setup` before use."""
        self.parallel = None
        self.mode = None
        self.use_ep_inside = None
        self.world_size = None
        # Cache of parallel info objects, keyed by expert parallel size.
        self._parallel_info_dict = dict()

        # router
        self.router_aux_loss = []
        self.router_z_loss = []

        # fixed mode
        self.pp_size = None
        self.dp_size = None
        self.ep_size = None

        # dynamic mode
        # Users may want to set maximum expert parallel size smaller than the world size
        # since very low bandwidth across nodes may constrain the performance of MoE
        # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
        self.max_ep_size = None

        self.has_setup = False

    @property
    def parallel_info_dict(self):
        """dict: Cached parallel info objects, keyed by expert parallel size."""
        return self._parallel_info_dict

    @property
    def is_initialized(self):
        """bool: Whether :meth:`setup` has completed successfully."""
        return self.has_setup

    def setup(
        self,
        parallel: str = None,
        mode: str = "dynamic",
        max_ep_size: int = 8,
        fixed_dp_size: int = 0,
        fixed_ep_size: int = 0,
        fixed_pp_size: int = 0,
        use_ep_inside: bool = True,
    ) -> None:
        """Setup MoE distributed context.

        Args:
            parallel (str, optional): Parallel mode, should be EP, TP or None. Defaults to None.
            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
                In fixed mode, the ep size and dp size is fixed.
                In dynamic mode, the ep size and dp size will be changed according to num experts.
            max_ep_size (int, optional): Max ep size in dynamic mode; it is capped at the
                world size. Defaults to 8.
            fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
            fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
            fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False.
                Defaults to True.

        Raises:
            AssertionError: If the manager has already been set up, CUDA is not
                available, ``mode`` is not "fixed" or "dynamic", or (in fixed mode)
                any of the fixed sizes is not a positive int.
        """
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

        self.parallel = parallel
        self.use_ep_inside = use_ep_inside
        self.world_size = dist.get_world_size()

        # init by mode
        self.mode = mode
        assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
        if self.mode == "dynamic":
            self.max_ep_size = min(max_ep_size, self.world_size)
        else:
            # Check the types before comparing against 0, so that a non-int value
            # (e.g. None) fails with this message instead of a raw TypeError.
            assert (
                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
            ), "dp_size, ep_size and pp_size should be int"
            assert (
                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
            ), "dp_size, ep_size and pp_size should be greater than 0"
            self.ep_size = fixed_ep_size
            self.dp_size = fixed_dp_size
            self.pp_size = fixed_pp_size

        # Only mark the manager as ready once every check above has passed.
        self.has_setup = True

    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.

        In dynamic mode the ep and dp sizes are derived from ``num_experts``, the
        world size and ``max_ep_size`` (pp size is 1). In fixed mode the sizes
        given to :meth:`setup` are used as-is.

        Args:
            num_experts (int): The number of experts.
            use_tp (bool, optional): If True, the number of local experts equals
                ``num_experts`` instead of being divided by the ep size.
                Defaults to False.

        Returns:
            Tuple[int, MoeParallelInfo]: The number of local experts and the
            MoeParallelInfo of the current ep_size.

        Raises:
            AssertionError: In dynamic mode, if ``num_experts`` is not positive, or
                if neither ``num_experts`` nor ``max_ep_size`` is a multiple of
                the other.
        """
        if self.mode == "dynamic":
            # Guard the modulo below: num_experts == 0 would raise ZeroDivisionError
            # and a negative count would produce negative group sizes.
            assert num_experts > 0, "num_experts should be greater than 0"
            gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
            lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
            assert gt_flag or lt_flag, (
                "Automatic experts placement dose not not support expert number"
                " is not a multiple of ep size or vice versa."
            )
            # Provisional dp size, used only to bound ep_size; the final dp size
            # is then recomputed from the chosen ep_size.
            dp_size = 1 if gt_flag else self.world_size // num_experts
            ep_size = min(self.world_size // dp_size, self.max_ep_size)
            dp_size = self.world_size // ep_size
            pp_size = 1
        else:
            dp_size = self.dp_size
            ep_size = self.ep_size
            pp_size = self.pp_size

        # Calculate the number of experts for each GPU
        if use_tp:
            num_local_experts = num_experts
        else:
            if self.mode == "dynamic":
                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
            else:
                num_local_experts = num_experts // ep_size

        # The parallel info is created once per ep_size and reused afterwards.
        if ep_size not in self.parallel_info_dict:
            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
            # Report the layout only from rank 0, and only when it is first created.
            if dist.get_rank() == 0:
                if self.use_ep_inside:
                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
                else:
                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

        return num_local_experts, self.parallel_info_dict[ep_size]

    def reset_loss(self):
        """Discard all recorded router losses."""
        self.router_aux_loss, self.router_z_loss = [], []

    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
        """Record one auxiliary loss and one z-loss value.

        Args:
            aux_loss (float, optional): Router auxiliary loss. Defaults to 0.0.
            z_loss (float, optional): Router z-loss. Defaults to 0.0.
        """
        self.router_aux_loss.append(aux_loss)
        self.router_z_loss.append(z_loss)

    def get_loss(self):
        """Return the recorded losses as ``(router_aux_loss, router_z_loss)``.

        The returned lists are the manager's own lists, not copies.
        """
        cur_loss = self.router_aux_loss, self.router_z_loss
        return cur_loss

    def get_parallel(self):
        """Return the parallel mode passed to :meth:`setup` (None before setup)."""
        return self.parallel
# Module-level MoEManager instance, created when this module is imported.
# It is not ready for use until its setup() method has been called.
MOE_MANAGER = MoEManager()