Added MoE parallel (#127)

pull/129/head
HELSON 2022-01-07 15:08:36 +08:00 committed by GitHub
parent 42741dd4a3
commit dceae85195
26 changed files with 858 additions and 18 deletions

View File

@@ -15,7 +15,8 @@ INITIALIZER_MAPPING = {
'2.5d': 'Initializer_2p5D',
'3d': 'Initializer_3D',
'sequence': 'Initializer_Sequence',
'model': 'Initializer_Model',
'moe': 'Initializer_Moe'
}
# 1D parallel

View File

@@ -15,6 +15,7 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
from colossalai.global_variables import moe_env
class ParallelContext:
@@ -412,6 +413,13 @@ class ParallelContext:
# add this config to initialize later
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
# initialization for moe environment
if parallel_config is not None and 'moe' in parallel_config:
param = parallel_config['moe']
assert 'size' in param, "Moe model parallel size should be given"
moe_env.setup(param['size'])
pg_init.append(dict(type=INITIALIZER_MAPPING['moe']))
# run initialization of different process groups
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
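With this hook in place, MoE parallelism is driven by the 'moe' entry of the parallel section in the config. A minimal sketch of such a config follows; the size value is illustrative, and only the 'moe' entry is what the ParallelContext code above actually reads.

# Illustrative config snippet (values assumed): 'size' is the MoE model parallel
# size and is required; moe_env.setup(size) then derives the MoE data parallel
# size as gpc.data_parallel_size // size.
parallel = dict(
    moe=dict(size=2),
)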

View File

@@ -44,3 +44,7 @@ class ParallelMode(Enum):
PARALLEL_2P5D_COL = '2p5d_col'
PARALLEL_2P5D_DEP = '2p5d_dep'
PARALLEL_2P5D_XZ = '2p5d_xz'
# MOE parallel
MOE_DATA = 'moe_data'
MOE_MODEL = 'moe_model'

View File

@@ -7,10 +7,12 @@ from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
from .initializer_model import Initializer_Model
from .initializer_moe import Initializer_Moe
from .process_group_initializer import ProcessGroupInitializer
__all__ = [
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model',
'Initializer_Moe'
]

View File

@@ -0,0 +1,97 @@
import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.global_variables import moe_env
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moemodel(ProcessGroupInitializer):
"""Model parallel initialization for MoE system.
"""
def __init__(self, moe_model, moe_data, *args, **kwargs):
super().__init__(*args, **kwargs)
self.moe_model = moe_model
self.moe_data = moe_data
def init_dist_group(self):
"""Initialize model parallel groups in moe parallel environment,
and assign local_ranks and groups to each gpu.
"""
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.MOE_MODEL
for i in range(self.moe_data):
ranks = [i * self.moe_model + j for j in range(self.moe_model)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moedata(ProcessGroupInitializer):
"""Data parallel initialization for MoE system.
"""
def __init__(self, moe_model, moe_data, *args, **kwargs):
super().__init__(*args, **kwargs)
self.moe_model = moe_model
self.moe_data = moe_data
def init_dist_group(self):
"""Initialize data parallel groups in moe parallel environment,
and assign local_ranks and groups to each gpu.
"""
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.MOE_DATA
for i in range(self.moe_model):
ranks = [i + j * self.moe_model for j in range(self.moe_data)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moe(ProcessGroupInitializer):
"""Serves as the single entry point to MoE parallel initialization.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.moe_model = moe_env.model_parallel_size
self.moe_data = moe_env.data_parallel_size
self.model_initializer = Initializer_Moemodel(
self.moe_model, self.moe_data, *args, **kwargs)
self.data_initializer = Initializer_Moedata(
self.moe_model, self.moe_data, *args, **kwargs)
def init_dist_group(self):
"""Initializes MoE parallel communication groups.
"""
parallel_setting = [self.model_initializer.init_dist_group(),
self.data_initializer.init_dist_group()]
return parallel_setting
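As a sketch of the grouping logic above: with a world size of 8 and a MoE model parallel size of 2 (so the MoE data parallel size is 4), the two loops build the following rank groups. The snippet below only reproduces the arithmetic, without any torch.distributed calls; the sizes are assumed for illustration.

# Standalone illustration of the rank grouping used by Initializer_Moemodel
# and Initializer_Moedata above.
moe_model, moe_data = 2, 4   # assumed sizes for this example

model_groups = [[i * moe_model + j for j in range(moe_model)] for i in range(moe_data)]
data_groups = [[i + j * moe_model for j in range(moe_data)] for i in range(moe_model)]

print(model_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]  -> MOE_MODEL groups
print(data_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]      -> MOE_DATA groups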

View File

@@ -1,8 +1,9 @@
from ._helper import (seed, set_mode, with_seed, add_seed,
get_seeds, get_states, get_current_mode,
set_seed_states, sync_states, moe_set_seed)
__all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
'moe_set_seed'
]

View File

@@ -49,7 +49,7 @@ def get_current_mode():
return _SEED_MANAGER.current_mode
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
@@ -59,7 +59,7 @@ def add_seed(parallel_mode: ParallelMode, seed: int):
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
def set_mode(parallel_mode: ParallelMode):
@@ -142,3 +142,16 @@ def with_seed(func, parallel_mode: ParallelMode):
return out
return wrapper
def moe_set_seed(seed):
if torch.cuda.is_available():
from colossalai.core import global_context as gpc
moe_mp_rank = gpc.get_local_rank(ParallelMode.MOE_MODEL)
moe_mp_seed = seed + moe_mp_rank
add_seed(ParallelMode.MOE_MODEL, moe_mp_seed)
global_rank = gpc.get_global_rank()
add_seed(ParallelMode.TENSOR, global_rank, True)
print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
f"tensor seed {global_rank}", flush=True)

View File

@@ -54,7 +54,7 @@ class SeedManager:
self._current_mode = parallel_mode
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
@@ -66,7 +66,11 @@ class SeedManager:
"""
assert isinstance(
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
if overwrite is False:
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
elif parallel_mode in self._seed_states:
print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
current_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()

View File

@@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
'MoeGradientHandler']

View File

@@ -0,0 +1,61 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.global_variables import moe_env
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group and
moe tensor parallel group. An all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
"""
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across the moe tensor parallel group.
"""
moe_data = moe_env.data_parallel_size
global_data = gpc.data_parallel_size
if global_data > 1:
# bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self._model.parameters():
if param.requires_grad and \
param.grad is not None and \
not hasattr(param, 'moe_param'):
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
# param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= gpc.get_world_size(ParallelMode.DATA)
dist.all_reduce(
coalesced, group=gpc.get_group(ParallelMode.DATA))
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
if global_data > 1:
for param in self._model.parameters():
if not param.requires_grad or param.grad is None:
continue
if moe_data > 1 and hasattr(param, 'moe_param'):
param.grad.data /= moe_data
dist.all_reduce(param.grad.data,
group=gpc.get_group(ParallelMode.MOE_DATA))

View File

@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
return data
@staticmethod
def _check_sanity(data, tag: str):
assert isinstance(data, (torch.Tensor, dict)), \
f'{tag} must be torch.Tensor or dict'
def load_batch(self, data_iter, to_gpu=True):
"""Loads a batch from data iterator. It returns the data and labels which are

View File

@@ -0,0 +1,36 @@
class MoeEnv:
"""Moe enviroment variable.
"""
def __init__(self):
self.data_parallel_size = None
self.model_parallel_size = None
self.aux_loss = None
def setup(self, moe_model_size):
from .core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")
assert gpc.data_parallel_size % moe_model_size == 0, \
"The size of data parallel needs to be divided by moe model parallel size"
self.data_parallel_size = gpc.data_parallel_size // moe_model_size
self.model_parallel_size = moe_model_size
def is_initialized(self):
return self.model_parallel_size is not None
def reset_loss(self):
self.aux_loss = 0
def add_loss(self, loss):
self.aux_loss += loss
def get_loss(self):
return self.aux_loss
moe_env = MoeEnv()

View File

@@ -5,7 +5,6 @@ import argparse
import pprint
import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
import torch.nn as nn
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.global_variables import moe_env
def get_default_parser():
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None,
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
:param model: your model instance
@@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# first sync model across dp ranks
model.to(get_current_device())
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
if not moe_env.is_initialized() and not use_zero3:
sync_model_param_in_dp(model)
else:
print(
"Warning: The parameters of models is not automatically synchronized.\n"
"Please make sure that all parameters are the same in data parallel group.",
flush=True)
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
@@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_ddp() and moe_env.is_initialized():
gradient_handler_cfg = [dict(type='MoeGradientHandler')]
if verbose:
logger.info(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
if verbose:

View File

@@ -0,0 +1,8 @@
from ._operation import AllToAll
from .layers import Experts, MoeLayer, \
NormalNoiseGenerator, Top1Router, Top2Router
__all__ = [
'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
'MoeLayer', 'NormalNoiseGenerator'
]

View File

@@ -0,0 +1,29 @@
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from typing import Any, Tuple
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any,
inputs: Tensor,
parallel_mode: ParallelMode) -> Tensor:
ctx.parallel_mode = parallel_mode
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs,
group=gpc.get_group(parallel_mode))
return output
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None

View File

@@ -0,0 +1,242 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
from ._operation import AllToAll
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is an instance of the class 'expert' given in the initialization arguments.
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([
expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', 1)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=0)
return output
class Top1Router(nn.Module):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in Google's Switch Transformer paper.
"""
def __init__(self,
capacity_factor: float,
min_capacity: int,
noisy_func=None):
super().__init__()
self.capacity_factor = capacity_factor
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())).rsample
def get_capacity(self, logits_shape):
capacity = math.ceil(self.capacity_factor *
logits_shape[0] / logits_shape[1])
if capacity < self.min_capacity:
capacity = self.min_capacity
return capacity
def forward(self, inputs):
if self.noisy_func is not None:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
logits = F.softmax(inputs, dim=1)
num_experts = logits.shape[1]
capacity = self.get_capacity(logits.shape)
expert_idx = torch.argmax(inputs_noisy, dim=1)
expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
expert_mask_f = expert_mask.float()
exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
me = torch.mean(logits, dim=0)
ce = torch.mean(expert_mask_f, dim=0)
l_aux = torch.sum(me * ce) * num_experts
moe_env.add_loss(l_aux)
rand_mask = expert_mask * self.uniform(logits.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
dispatch_mask = \
expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
locations = torch.cumsum(dispatch_mask, dim=0) - 1
locations = torch.sum(dispatch_mask * locations, dim=1)
locations = F.one_hot(locations, num_classes=capacity)
logits = logits * dispatch_mask
combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask, exp_counts
class Top2Router(nn.Module):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More detailed information can be found in the ViT-MoE paper.
"""
def __init__(self, capacity_factor: float, noisy_func=None):
super().__init__()
self.capacity_factor = capacity_factor
self.noisy_func = noisy_func
def get_capacity(self, logits_shape):
capacity = math.ceil(2 * self.capacity_factor *
logits_shape[0] / logits_shape[1])
return capacity
def forward(self, inputs):
if self.noisy_func is not None:
inputs = self.noisy_func(inputs)
logits = F.softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
_, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
top1_idx = expert_idx[:, 0]
top2_idx = expert_idx[:, 1]
mask1 = F.one_hot(top1_idx, num_classes=num_experts)
mask2 = F.one_hot(top2_idx, num_classes=num_experts)
loss_mask = (mask1 + mask2)
exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
me = torch.mean(logits, dim=0)
ce = torch.mean(loss_mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
moe_env.add_loss(l_aux)
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
locations2 += torch.sum(mask1, dim=0, keepdim=True)
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
weight1 = mask1 * logits
weight2 = mask2 * logits
locations1 = torch.sum(mask1 * locations1, dim=1)
locations2 = torch.sum(mask2 * locations2, dim=1)
locations1_sc = F.one_hot(locations1, num_classes=capacity)
locations2_sc = F.one_hot(locations2, num_classes=capacity)
combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
combine_weights = combine_weights1 + combine_weights2
sec_mask = combine_weights.bool()
return combine_weights, sec_mask, exp_counts
class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all comunication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
"""
def __init__(self,
dim_model: int,
num_experts: int,
router: nn.Module,
experts: nn.Module):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
self.router = router
self.experts = experts
def _router_part(self, tokens: torch.Tensor):
gate_output = self.gate(tokens)
return self.router(gate_output)
def router_part(self, tokens: torch.Tensor):
autocast_context = torch.is_autocast_enabled()
if not autocast_context:
return self._router_part(tokens)
else:
with autocast(enabled=False):
if tokens.dtype == torch.float16:
input_tokens = tokens.float()
else:
input_tokens = tokens
return self._router_part(input_tokens)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
combine_weights, sec_mask, exp_counts = self.router_part(tokens)
sec_mask_f = sec_mask.type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
expert_output = self.experts(dispatch_data)
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
return ret
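The capacity computed by Top1Router.get_capacity and Top2Router.get_capacity above works out as in this small sketch; the token and expert counts are made up for illustration.

import math

# Mirrors Top1Router.get_capacity / Top2Router.get_capacity above.
capacity_factor, min_capacity = 1.25, 4
num_tokens, num_experts = 1024, 8          # logits shape is [num_tokens, num_experts]

top1_capacity = max(math.ceil(capacity_factor * num_tokens / num_experts), min_capacity)
top2_capacity = math.ceil(2 * capacity_factor * num_tokens / num_experts)

print(top1_capacity, top2_capacity)        # 160 320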

View File

@@ -1,3 +1,5 @@
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
WrappedDropout, WrappedDropPath
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
'WrappedDropout', 'WrappedDropPath']

View File

@@ -10,6 +10,7 @@ from torch import Tensor, dtype
from torch import nn as nn
from ..utils import to_2tuple
from colossalai.context import seed
def drop_path(x, drop_prob: float = 0., training: bool = False):
@@ -42,6 +43,58 @@ class DropPath(nn.Module):
return drop_path(x, self.drop_prob, self.training)
class WrappedDropout(nn.Module):
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager.
"""
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace
if mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return F.dropout(inputs, self.p, self.training, self.inplace)
def normalfunc(self, inputs):
with seed(self.mode):
return F.dropout(inputs, self.p, self.training, self.inplace)
def forward(self, inputs):
return self.func(inputs)
class WrappedDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped within the context of the seed manager.
"""
def __init__(self, p: float = 0., mode=None):
super().__init__()
self.p = p
self.mode = mode
if self.mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return drop_path(inputs, self.p, self.training)
def normalfunc(self, inputs):
with seed(self.mode):
return drop_path(inputs, self.p, self.training)
def forward(self, inputs):
return self.func(inputs)
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding

View File

@@ -6,6 +6,7 @@ from colossalai.nn.layer.utils import get_tensor_parallel_mode
from .loss_2d import CrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
_parallel_cross_entropy = {
'2d': CrossEntropyLoss2D,

View File

@@ -0,0 +1,34 @@
import torch.nn as nn
from colossalai.registry import LOSSES
from torch.nn.modules.loss import _Loss
from colossalai.global_variables import moe_env
@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
"""torch.nn.CrossEntropyLoss added with auxiliary loss.
"""
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
super().__init__()
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args):
main_loss = self.loss(*args)
aux_loss = moe_env.get_loss()
return main_loss + self.aux_weight * aux_loss
@LOSSES.register_module
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
"""
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
super().__init__()
self.loss_fn = loss_fn(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args, **kwargs):
main_loss = self.loss_fn(*args, **kwargs)
aux_loss = moe_env.get_loss()
return main_loss + self.aux_weight * aux_loss
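A minimal usage sketch of the two losses above, with dummy data and an arbitrary aux_weight. In real training the MoE model's forward populates the auxiliary loss (it calls moe_env.reset_loss() itself); reset_loss() appears here only so the sketch runs on its own.

import torch

moe_env.reset_loss()                        # normally done inside the MoE model's forward
outputs = torch.randn(4, 10)                # logits for 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))
criterion = MoeCrossEntropyLoss(aux_weight=0.01)
# or wrap any loss function; MoeLoss instantiates loss_fn internally:
# criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
loss = criterion(outputs, targets)          # main loss + aux_weight * moe_env.get_loss()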

View File

@@ -66,7 +66,7 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1):
base_scheduler = _CosineAnnealingLR(
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
super().__init__(optimizer, warmup_steps, base_scheduler)

View File

@@ -17,6 +17,7 @@ import torch.distributed as dist
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from .multi_tensor_apply import multi_tensor_applier
@@ -91,6 +92,10 @@ def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def is_moe_parallel_parameter(p):
return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
def _calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
@@ -165,26 +170,37 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
moe_parallel_grads = [] # used to collect moe tensor parallel gradients
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif is_moe_parallel_parameter(p):
moe_parallel_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(
tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(
no_tensor_parallel_grads) ** norm_type
moe_parallel_norm = _calc_l2_norm(
moe_parallel_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(
no_tensor_parallel_grads, norm_type)
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
# Sum across all moe-tensor-parallel GPUs
if len(moe_parallel_grads) > 0:
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
no_tensor_parallel_norm += moe_parallel_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm,

model_zoo/helper.py (new file, 26 lines)
View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
from colossalai.nn.layer import WrappedDropPath as DropPath
class TransformerLayer(nn.Module):
"""Transformer layer builder.
"""
def __init__(self,
att: nn.Module,
ffn: nn.Module,
norm1: nn.Module,
norm2: nn.Module,
droppath=None,
droppath_rate: float = 0):
super().__init__()
self.att = att
self.ffn = ffn
self.norm1 = norm1
self.norm2 = norm2
self.droppath = DropPath(droppath_rate) if droppath is None else droppath
def forward(self, x):
x = x + self.droppath(self.att(self.norm1(x)))
x = x + self.droppath(self.ffn(self.norm2(x)))
return x

model_zoo/moe/models.py (new file, 146 lines)
View File

@@ -0,0 +1,146 @@
import math
import torch
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
from colossalai.utils import get_current_device
class VanillaSelfAttention(nn.Module):
"""Standard ViT self attention.
"""
def __init__(self,
d_model: int,
n_heads: int,
d_kv: int,
attention_drop: float = 0,
drop_rate: float = 0,
bias: bool = True,
dropout1=None,
dropout2=None):
super().__init__()
self.n_heads = n_heads
self.d_kv = d_kv
self.scale = 1.0 / math.sqrt(self.d_kv)
self.dense1 = nn.Linear(d_model, 3 * n_heads * d_kv, bias, device=get_current_device())
self.softmax = nn.Softmax(dim=-1)
self.atten_drop = nn.Dropout(attention_drop) if dropout1 is None else dropout1
self.dense2 = nn.Linear(n_heads * d_kv, d_model, device=get_current_device())
self.dropout = nn.Dropout(drop_rate) if dropout2 is None else dropout2
def forward(self, x):
qkv = self.dense1(x)
new_shape = qkv.shape[:2] + (3, self.n_heads, self.d_kv)
qkv = qkv.view(*new_shape)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[:]
x = torch.matmul(q, k.transpose(-2, -1)) * self.scale
x = self.atten_drop(self.softmax(x))
x = torch.matmul(x, v)
x = x.transpose(1, 2)
new_shape = x.shape[:2] + (self.n_heads * self.d_kv,)
x = x.reshape(*new_shape)
x = self.dense2(x)
x = self.dropout(x)
return x
class VanillaFFN(nn.Module):
"""FFN composed with two linear layers, also called MLP.
"""
def __init__(self,
d_model: int,
d_ff: int,
activation=None,
drop_rate: float = 0,
bias: bool = True,
dropout1=None,
dropout2=None):
super().__init__()
dense1 = nn.Linear(d_model, d_ff, bias, device=get_current_device())
act = nn.GELU() if activation is None else activation
dense2 = nn.Linear(d_ff, d_model, bias, device=get_current_device())
drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1
drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2
self.ffn = nn.Sequential(
dense1, act, drop1, dense2, drop2)
def forward(self, x):
return self.ffn(x)
class Widenet(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
d_model: int = 768,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 3072,
attention_drop: float = 0.,
drop_rate: float = 0.1,
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
shared_sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv,
attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = Experts(expert=VanillaFFN,
num_experts=num_experts,
**moe_mlp_args(
d_model=d_model,
d_ff=d_ff,
drop_rate=drop_rate
))
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = [
TransformerLayer(
att=shared_sa,
ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
router=shared_router, experts=shared_experts),
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
)
for i in range(depth)
]
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model,
num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
def forward(self, x):
moe_env.reset_loss()
x = self.widenet(x)
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
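A hedged sketch of instantiating this model; the hyperparameters are illustrative, and it assumes colossalai has been launched with an MoE parallel config so that moe_env is set up and the MOE_MODEL and TENSOR seed modes exist.

# Illustrative only: requires a launched MoE parallel environment before construction.
model = Widenet(
    num_experts=4,        # should be divisible by the MoE model parallel size
    capacity_factor=1.2,
    img_size=224,
    patch_size=16,
    depth=12,
    d_model=768,
)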

model_zoo/moe/util.py (new file, 41 lines)
View File

@@ -0,0 +1,41 @@
from colossalai.context import ParallelMode
from colossalai.nn.layer import WrappedDropout as Dropout
def moe_sa_args(d_model: int,
n_heads: int,
d_kv: int,
attention_drop: float = 0,
drop_rate: float = 0,
bias: bool = True):
"""This is an example for args in moe self attention, since lots of modules should be
adapted before putting them in experts.
"""
dropout1 = Dropout(attention_drop, mode=ParallelMode.TENSOR)
dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
return dict(
d_model=d_model,
n_heads=n_heads,
d_kv=d_kv,
bias=bias,
dropout1=dropout1,
dropout2=dropout2
)
def moe_mlp_args(d_model: int,
d_ff: int,
drop_rate: float,
bias: bool = True):
"""This is an example for args of MLP in Experts, since lots of modules should be adapted
before putting them in experts.
"""
dropout1 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
return dict(
d_model=d_model,
d_ff=d_ff,
bias=bias,
dropout1=dropout1,
dropout2=dropout2
)