
[NFC] polish colossalai/context/moe_context.py code style (#2693)

pull/2694/head
LuGY committed 2 years ago (via GitHub)
commit 56ff1921e9
1 changed file: colossalai/context/moe_context.py (258 changed lines)

@@ -1,129 +1,129 @@
-import torch
-import torch.distributed as dist
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
-
-from typing import Tuple
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor import ProcessGroup


def _check_sanity():
    from colossalai.core import global_context as gpc
    if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
        raise NotImplementedError("Moe is not compatible with tensor or "
                                  "pipeline parallel at present.")


class MoeParallelInfo:
    """Moe parallelism information, storing parallel sizes and groups.
    """

    def __init__(self, ep_size: int, dp_size: int):
        _check_sanity()
        self.ep_size = ep_size
        self.dp_size = dp_size
        self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
        self.ep_group = self.pg.tp_process_group()
        self.dp_group = self.pg.dp_process_group()


class MoeContext(metaclass=SingletonMeta):
    """MoE parallel context manager. This class manages different
    parallel groups in MoE context and MoE loss in training.
    """

    def __init__(self):
        self.world_size = 1
        # Users may want to set the maximum expert parallel size smaller than the world size
        # since very low bandwidth across nodes may constrain the performance of MoE
        # When we have a maximum expert parallel size, we naturally have a minimum data parallel size
        self.max_ep_size = 1
        self.min_dp_size = 1
        self.aux_loss = None
        self.use_kernel_optim = True

        self.has_setup = False
        self._parallel_info_dict = dict()

    @property
    def parallel_info_dict(self):
        return self._parallel_info_dict

    @property
    def is_initialized(self):
        return self.has_setup

    def setup(self, seed: int, use_kernel_optim: bool = True):
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        _check_sanity()
        assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"

        self.world_size = dist.get_world_size()

        from colossalai.core import global_context as gpc
        self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
        assert self.world_size % self.max_ep_size == 0, \
            "Maximum expert parallel size must be a factor of the number of GPUs"
        self.min_dp_size = self.world_size // self.max_ep_size

        # Enabling kernel optimization may raise errors in some cases
        # Users can disable kernel optimization manually
        self.use_kernel_optim = use_kernel_optim

        from .random import moe_set_seed
        moe_set_seed(seed)
        self.has_setup = True

    def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.

        Parameters
        ----------
        num_experts : int
            The number of experts

        Returns
        -------
        int, MoeParallelInfo
            number of local experts, the MoeParallelInfo of the current ep_size
        """

        gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
        lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
        assert gt_flag or lt_flag, "Automatic expert placement does not support an expert number" \
                                   " that is not a multiple of the ep size, or vice versa."

        # If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
        # there are multiple experts on each GPU and each GPU has different experts,
        # so its data parallel size is 1.
        # Otherwise, there is only one expert on each GPU
        # and the data parallel size should be calculated
        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
        ep_size = self.max_ep_size // dp_size

        # Calculate the number of experts for each GPU
        num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size

        # Don't forget to multiply by the minimum data parallel size
        dp_size *= self.min_dp_size

        if not (ep_size in self.parallel_info_dict):
            self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)

        return num_local_experts, self.parallel_info_dict[ep_size]

    def set_kernel_not_use(self):
        self.use_kernel_optim = False

    def reset_loss(self):
        self.aux_loss = 0

    def add_loss(self, loss):
        self.aux_loss += loss

    def get_loss(self):
        return self.aux_loss


MOE_CONTEXT = MoeContext()
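
For orientation, a minimal usage sketch (not part of the diff above): it assumes torch.distributed and the colossalai global context are already initialized on 8 GPUs with max_ep_size set to 4 in the gpc config, so the sizes in the comments are only illustrative.

from colossalai.context.moe_context import MOE_CONTEXT

# world_size = 8, max_ep_size = 4  ->  min_dp_size = 8 // 4 = 2
MOE_CONTEXT.setup(seed=42)

# 8 experts >= max_ep_size: each rank holds 8 // 4 = 2 experts,
# so ep_size = 4 and dp_size = 1 * 2 = 2
num_local_experts, info = MOE_CONTEXT.get_info(num_experts=8)

# 2 experts <= max_ep_size: one expert per expert-parallel rank,
# so ep_size = 4 // 2 = 2 and dp_size = 2 * 2 = 4
num_local_experts, info = MOE_CONTEXT.get_info(num_experts=2)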
