mirror of https://github.com/hpcaitech/ColossalAI
Added MoE parallel (#127)
parent
42741dd4a3
commit
dceae85195
|
@ -15,7 +15,8 @@ INITIALIZER_MAPPING = {
|
|||
'2.5d': 'Initializer_2p5D',
|
||||
'3d': 'Initializer_3D',
|
||||
'sequence': 'Initializer_Sequence',
|
||||
'model': 'Initializer_Model'
|
||||
'model': 'Initializer_Model',
|
||||
'moe': 'Initializer_Moe'
|
||||
}
|
||||
|
||||
# 1D parallel
|
||||
|
|
|
@ -15,6 +15,7 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
|
|||
|
||||
from .parallel_mode import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
from colossalai.global_variables import moe_env
|
||||
|
||||
|
||||
class ParallelContext:
|
||||
|
@ -412,6 +413,13 @@ class ParallelContext:
|
|||
# add this config to initialize later
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
|
||||
|
||||
# initialization for moe environment
|
||||
if parallel_config is not None and 'moe' in parallel_config:
|
||||
param = parallel_config['moe']
|
||||
assert 'size' in param, "Moe model parallel size should be given"
|
||||
moe_env.setup(param['size'])
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['moe']))
|
||||
|
||||
# run initialization of different process groups
|
||||
for initializer_cfg in pg_init:
|
||||
cfg = initializer_cfg.copy()
|
||||
|
|
|
@ -44,3 +44,7 @@ class ParallelMode(Enum):
|
|||
PARALLEL_2P5D_COL = '2p5d_col'
|
||||
PARALLEL_2P5D_DEP = '2p5d_dep'
|
||||
PARALLEL_2P5D_XZ = '2p5d_xz'
|
||||
|
||||
# MOE parallel
|
||||
MOE_DATA = 'moe_data'
|
||||
MOE_MODEL = 'moe_model'
|
||||
|
|
|
@ -7,10 +7,12 @@ from .initializer_pipeline import Initializer_Pipeline
|
|||
from .initializer_sequence import Initializer_Sequence
|
||||
from .initializer_tensor import Initializer_Tensor
|
||||
from .initializer_model import Initializer_Model
|
||||
from .initializer_moe import Initializer_Moe
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
|
||||
__all__ = [
|
||||
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
|
||||
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
|
||||
'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
|
||||
'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model',
|
||||
'Initializer_Moe'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
import torch.distributed as dist
|
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from colossalai.global_variables import moe_env
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Moemodel(ProcessGroupInitializer):
|
||||
"""Model parallel initialization for MoE system.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_model, moe_data, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.moe_model = moe_model
|
||||
self.moe_data = moe_data
|
||||
|
||||
def init_dist_group(self):
|
||||
"""Initialize model parallel groups in moe parallel environment,
|
||||
and assign local_ranks and groups to each gpu.
|
||||
"""
|
||||
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.MOE_MODEL
|
||||
|
||||
for i in range(self.moe_data):
|
||||
ranks = [i * self.moe_model + j for j in range(self.moe_model)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Moedata(ProcessGroupInitializer):
|
||||
"""Data parallel initialization for MoE system.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_model, moe_data, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.moe_model = moe_model
|
||||
self.moe_data = moe_data
|
||||
|
||||
def init_dist_group(self):
|
||||
"""Initialize data parallel groups in moe parallel environment,
|
||||
and assign local_ranks and groups to each gpu.
|
||||
"""
|
||||
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.MOE_DATA
|
||||
|
||||
for i in range(self.moe_model):
|
||||
ranks = [i + j * self.moe_model for j in range(self.moe_data)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Moe(ProcessGroupInitializer):
|
||||
"""Serves as the single entry point to MoE parallel initialization.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.moe_model = moe_env.model_parallel_size
|
||||
self.moe_data = moe_env.data_parallel_size
|
||||
self.model_initializer = Initializer_Moemodel(
|
||||
self.moe_model, self.moe_data, *args, **kwargs)
|
||||
self.data_initializer = Initializer_Moedata(
|
||||
self.moe_model, self.moe_data, *args, **kwargs)
|
||||
|
||||
def init_dist_group(self):
|
||||
"""Initializes MoE parallel communication groups.
|
||||
"""
|
||||
|
||||
parallel_setting = [self.model_initializer.init_dist_group(),
|
||||
self.data_initializer.init_dist_group()]
|
||||
return parallel_setting
|
|
@ -1,8 +1,9 @@
|
|||
from ._helper import (seed, set_mode, with_seed, add_seed,
|
||||
get_seeds, get_states, get_current_mode,
|
||||
set_seed_states, sync_states)
|
||||
set_seed_states, sync_states, moe_set_seed)
|
||||
|
||||
__all__ = [
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
|
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states'
|
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
|
||||
'moe_set_seed'
|
||||
]
|
||||
|
|
|
@ -49,7 +49,7 @@ def get_current_mode():
|
|||
return _SEED_MANAGER.current_mode
|
||||
|
||||
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int):
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
|
@ -59,7 +59,7 @@ def add_seed(parallel_mode: ParallelMode, seed: int):
|
|||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
_SEED_MANAGER.add_seed(parallel_mode, seed)
|
||||
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
|
||||
|
||||
|
||||
def set_mode(parallel_mode: ParallelMode):
|
||||
|
@ -142,3 +142,16 @@ def with_seed(func, parallel_mode: ParallelMode):
|
|||
return out
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def moe_set_seed(seed):
|
||||
if torch.cuda.is_available():
|
||||
from colossalai.core import global_context as gpc
|
||||
moe_mp_rank = gpc.get_local_rank(ParallelMode.MOE_MODEL)
|
||||
moe_mp_seed = seed + moe_mp_rank
|
||||
add_seed(ParallelMode.MOE_MODEL, moe_mp_seed)
|
||||
|
||||
global_rank = gpc.get_global_rank()
|
||||
add_seed(ParallelMode.TENSOR, global_rank, True)
|
||||
print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
|
||||
f"tensor seed {global_rank}", flush=True)
|
||||
|
|
|
@ -54,7 +54,7 @@ class SeedManager:
|
|||
self._current_mode = parallel_mode
|
||||
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
|
||||
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int):
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
|
@ -66,7 +66,11 @@ class SeedManager:
|
|||
"""
|
||||
assert isinstance(
|
||||
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
if overwrtie is False:
|
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
|
||||
elif parallel_mode in self._seed_states:
|
||||
print(f"Warnning: {parallel_mode} seed has been overwritten.", flush=True)
|
||||
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.manual_seed(seed)
|
||||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
|
||||
|
|
|
@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
|
|||
from ._data_parallel_gradient_handler import DataParallelGradientHandler
|
||||
from ._zero_gradient_handler import ZeROGradientHandler
|
||||
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from ._moe_gradient_handler import MoeGradientHandler
|
||||
|
||||
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
|
||||
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
|
||||
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
|
||||
'MoeGradientHandler']
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import GRADIENT_HANDLER
|
||||
from colossalai.global_variables import moe_env
|
||||
from ._base_gradient_handler import BaseGradientHandler
|
||||
from ...context.parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@GRADIENT_HANDLER.register_module
|
||||
class MoeGradientHandler(BaseGradientHandler):
|
||||
"""A helper class to handle all-reduce operations in a data parallel group and
|
||||
moe tensor parallel. A all-reduce collective communication will be operated in
|
||||
:func:`handle_gradient` among a data parallel group.
|
||||
For better performance, it bucketizes the gradients of all parameters that are
|
||||
the same type to improve the efficiency of communication.
|
||||
"""
|
||||
|
||||
def handle_gradient(self):
|
||||
"""A method running an all-reduce operation in a data parallel group.
|
||||
Then running an all-reduce operation for all parameters in experts
|
||||
across moe tensor parallel group
|
||||
"""
|
||||
moe_data = moe_env.data_parallel_size
|
||||
global_data = gpc.data_parallel_size
|
||||
|
||||
if global_data > 1:
|
||||
# bucketize and all-reduce
|
||||
buckets = {}
|
||||
# Pack the buckets.
|
||||
for param in self._model.parameters():
|
||||
if param.requires_grad and \
|
||||
param.grad is not None and \
|
||||
not hasattr(param, 'moe_param'):
|
||||
tp = param.data.type()
|
||||
if tp not in buckets:
|
||||
buckets[tp] = []
|
||||
buckets[tp].append(param)
|
||||
# param.main_grad = param.grad
|
||||
|
||||
# For each bucket, all-reduce and copy all-reduced grads.
|
||||
for tp in buckets:
|
||||
bucket = buckets[tp]
|
||||
grads = [param.grad.data for param in bucket]
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
coalesced /= gpc.get_world_size(ParallelMode.DATA)
|
||||
|
||||
dist.all_reduce(
|
||||
coalesced, group=gpc.get_group(ParallelMode.DATA))
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(
|
||||
coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
if global_data > 1:
|
||||
for param in self._model.parameters():
|
||||
if not param.requires_grad or param.grad is None:
|
||||
continue
|
||||
if moe_data > 1 and hasattr(param, 'moe_param'):
|
||||
param.grad.data /= moe_data
|
||||
dist.all_reduce(param.grad.data,
|
||||
group=gpc.get_group(ParallelMode.MOE_DATA))
|
|
@ -38,8 +38,9 @@ class BaseSchedule(ABC):
|
|||
return data
|
||||
|
||||
@staticmethod
|
||||
def _check_sanity(data, tag):
|
||||
assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
|
||||
def _check_sanity(data, tag: str):
|
||||
assert isinstance(data, (torch.Tensor, dict)), \
|
||||
f'{tag} must be torch.Tensor or dict'
|
||||
|
||||
def load_batch(self, data_iter, to_gpu=True):
|
||||
"""Loads a batch from data iterator. It returns the data and labels which are
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
|
||||
|
||||
class MoeEnv:
|
||||
"""Moe enviroment variable.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data_parallel_size = None
|
||||
self.model_parallel_size = None
|
||||
self.aux_loss = None
|
||||
|
||||
def setup(self, moe_model_size):
|
||||
from .core import global_context as gpc
|
||||
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")
|
||||
|
||||
assert gpc.data_parallel_size % moe_model_size == 0, \
|
||||
"The size of data parallel needs to be divided by moe model parallel size"
|
||||
|
||||
self.data_parallel_size = gpc.data_parallel_size // moe_model_size
|
||||
self.model_parallel_size = moe_model_size
|
||||
|
||||
def is_initialized(self):
|
||||
return self.model_parallel_size is not None
|
||||
|
||||
def reset_loss(self):
|
||||
self.aux_loss = 0
|
||||
|
||||
def add_loss(self, loss):
|
||||
self.aux_loss += loss
|
||||
|
||||
def get_loss(self):
|
||||
return self.aux_loss
|
||||
|
||||
|
||||
moe_env = MoeEnv()
|
|
@ -5,7 +5,6 @@ import argparse
|
|||
import pprint
|
||||
import os
|
||||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.global_variables import moe_env
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
||||
lr_scheduler: _LRScheduler = None,
|
||||
verbose: bool = True
|
||||
) -> Tuple[Engine, DataLoader, DataLoader]:
|
||||
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
|
||||
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
|
||||
|
||||
:param model: your model instance
|
||||
|
@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
|
||||
if not use_zero3:
|
||||
if not moe_env.is_initialized() and not use_zero3:
|
||||
sync_model_param_in_dp(model)
|
||||
else:
|
||||
print(
|
||||
"Warning: The parameters of models is not automatically synchronized.\n"
|
||||
"Please make sure that all parameters are the same in data parallel group.",
|
||||
flush=True)
|
||||
|
||||
# check amp and zero
|
||||
fp16_cfg = gpc.config.get('fp16', None)
|
||||
|
@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||
"Training with zero is detected, ZeROGradientHandler is automatically "
|
||||
"added even though not specified in the configuration",
|
||||
ranks=[0])
|
||||
elif is_using_ddp() and moe_env.is_initialized():
|
||||
gradient_handler_cfg = [dict(type='MoeGradientHandler')]
|
||||
if verbose:
|
||||
logger.info(
|
||||
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
|
||||
"added even though not specified in the configuration",
|
||||
ranks=[0])
|
||||
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
|
||||
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
|
||||
if verbose:
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
from ._operation import AllToAll
|
||||
from .layers import Experts, MoeLayer, \
|
||||
NormalNoiseGenerator, Top1Router, Top2Router
|
||||
|
||||
__all__ = [
|
||||
'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
|
||||
'MoeLayer', 'NormalNoiseGenerator'
|
||||
]
|
|
@ -0,0 +1,29 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from typing import Any, Tuple
|
||||
|
||||
|
||||
class AllToAll(torch.autograd.Function):
|
||||
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
|
||||
operation in torch.distributed.
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.parallel_mode = parallel_mode
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
output = torch.empty_like(inputs)
|
||||
dist.all_to_all_single(output, inputs,
|
||||
group=gpc.get_group(parallel_mode))
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
|
|
@ -0,0 +1,242 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
from colossalai.global_variables import moe_env
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import AllToAll
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
"""Generates a random noisy mask for logtis tensor.
|
||||
All noise is generated from a normal distribution (0, 1 / E^2), where
|
||||
E = the number of experts.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int):
|
||||
self.normal = torch.distributions.normal.Normal(
|
||||
loc=torch.tensor(0.0, device=get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.normal(inputs.shape)
|
||||
return inputs + noisy
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
"""A wrapper class to create experts. It will create E experts across the
|
||||
moe model parallel group, where E is the number of experts. Every expert
|
||||
is a instence of the class, 'expert' in initialization parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, expert, num_experts, **expert_args):
|
||||
super().__init__()
|
||||
|
||||
assert num_experts % moe_env.model_parallel_size == 0, \
|
||||
"The number of experts should be divied by moe model size"
|
||||
|
||||
num_local_experts = num_experts // moe_env.model_parallel_size
|
||||
with seed(ParallelMode.MOE_MODEL):
|
||||
self.experts = nn.ModuleList([
|
||||
expert(**expert_args) for _ in range(num_local_experts)])
|
||||
self.num_local_experts = num_local_experts
|
||||
for exp in self.experts:
|
||||
for param in exp.parameters():
|
||||
param.__setattr__('moe_param', 1)
|
||||
|
||||
def forward(self, inputs):
|
||||
expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
|
||||
expert_output = []
|
||||
|
||||
for i in range(self.num_local_experts):
|
||||
expert_output.append(self.experts[i](expert_input[i]))
|
||||
|
||||
output = torch.cat(expert_output, dim=0)
|
||||
return output
|
||||
|
||||
|
||||
class Top1Router(nn.Module):
|
||||
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More deailted function can be found in the paper about Switch Transformer
|
||||
of Google.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor: float,
|
||||
min_capacity: int,
|
||||
noisy_func=None):
|
||||
super().__init__()
|
||||
self.capacity_factor = capacity_factor
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_current_device())).rsample
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity = math.ceil(self.capacity_factor *
|
||||
logits_shape[0] / logits_shape[1])
|
||||
if capacity < self.min_capacity:
|
||||
capacity = self.min_capacity
|
||||
return capacity
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
if self.noisy_func is not None:
|
||||
inputs_noisy = self.noisy_func(inputs)
|
||||
else:
|
||||
inputs_noisy = inputs
|
||||
|
||||
logits = F.softmax(inputs, dim=1)
|
||||
|
||||
num_experts = logits.shape[1]
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
expert_idx = torch.argmax(inputs_noisy, dim=1)
|
||||
expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
|
||||
expert_mask_f = expert_mask.float()
|
||||
|
||||
exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
|
||||
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(expert_mask_f, dim=0)
|
||||
l_aux = torch.sum(me * ce) * num_experts
|
||||
moe_env.add_loss(l_aux)
|
||||
|
||||
rand_mask = expert_mask * self.uniform(logits.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
|
||||
dispatch_mask = \
|
||||
expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
|
||||
|
||||
locations = torch.cumsum(dispatch_mask, dim=0) - 1
|
||||
locations = torch.sum(dispatch_mask * locations, dim=1)
|
||||
locations = F.one_hot(locations, num_classes=capacity)
|
||||
|
||||
logits = logits * dispatch_mask
|
||||
combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
|
||||
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask, exp_counts
|
||||
|
||||
|
||||
class Top2Router(nn.Module):
|
||||
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More deailted function can be found in the paper about ViT-MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity_factor: float, noisy_func=None):
|
||||
super().__init__()
|
||||
self.capacity_factor = capacity_factor
|
||||
self.noisy_func = noisy_func
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity = math.ceil(2 * self.capacity_factor *
|
||||
logits_shape[0] / logits_shape[1])
|
||||
return capacity
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.noisy_func is not None:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
_, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
|
||||
top1_idx = expert_idx[:, 0]
|
||||
top2_idx = expert_idx[:, 1]
|
||||
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts)
|
||||
|
||||
loss_mask = (mask1 + mask2)
|
||||
exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(loss_mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0
|
||||
moe_env.add_loss(l_aux)
|
||||
|
||||
locations1 = torch.cumsum(mask1, dim=0) - 1
|
||||
locations2 = torch.cumsum(mask2, dim=0) - 1
|
||||
locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(locations1, capacity)
|
||||
mask2 *= torch.lt(locations2, capacity)
|
||||
|
||||
weight1 = mask1 * logits
|
||||
weight2 = mask2 * logits
|
||||
|
||||
locations1 = torch.sum(mask1 * locations1, dim=1)
|
||||
locations2 = torch.sum(mask2 * locations2, dim=1)
|
||||
locations1_sc = F.one_hot(locations1, num_classes=capacity)
|
||||
locations2_sc = F.one_hot(locations2, num_classes=capacity)
|
||||
|
||||
combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
|
||||
combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
|
||||
combine_weights = combine_weights1 + combine_weights2
|
||||
sec_mask = combine_weights.bool()
|
||||
|
||||
return combine_weights, sec_mask, exp_counts
|
||||
|
||||
|
||||
class MoeLayer(nn.Module):
|
||||
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
|
||||
to router all tokens, is mainly used to exchange all tokens for every expert across
|
||||
the moe tensor group by all to all comunication. Then it will get the output of all
|
||||
experts and exchange the output. At last returns the output of the moe system.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim_model: int,
|
||||
num_experts: int,
|
||||
router: nn.Module,
|
||||
experts: nn.Module):
|
||||
super().__init__()
|
||||
self.d_model = dim_model
|
||||
self.num_experts = num_experts
|
||||
self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
|
||||
self.router = router
|
||||
self.experts = experts
|
||||
|
||||
def _router_part(self, tokens: torch.Tensor):
|
||||
gate_output = self.gate(tokens)
|
||||
return self.router(gate_output)
|
||||
|
||||
def router_part(self, tokens: torch.Tensor):
|
||||
autocast_context = torch.is_autocast_enabled()
|
||||
if not autocast_context:
|
||||
return self._router_part(tokens)
|
||||
else:
|
||||
with autocast(enabled=False):
|
||||
if tokens.dtype == torch.float16:
|
||||
input_tokens = tokens.float()
|
||||
else:
|
||||
input_tokens = tokens
|
||||
return self._router_part(input_tokens)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
tokens = inputs.reshape(-1, self.d_model)
|
||||
|
||||
combine_weights, sec_mask, exp_counts = self.router_part(tokens)
|
||||
|
||||
sec_mask_f = sec_mask.type_as(inputs)
|
||||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
|
||||
|
||||
expert_output = self.experts(dispatch_data)
|
||||
|
||||
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
|
||||
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
|
||||
ret = torch.matmul(combine_weights, expert_output)
|
||||
ret = ret.reshape(inputs.shape)
|
||||
|
||||
return ret
|
|
@ -1,3 +1,5 @@
|
|||
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding
|
||||
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
|
||||
WrappedDropout, WrappedDropPath
|
||||
|
||||
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath']
|
||||
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
|
||||
'WrappedDropout', 'WrappedDropPath']
|
||||
|
|
|
@ -10,6 +10,7 @@ from torch import Tensor, dtype
|
|||
from torch import nn as nn
|
||||
|
||||
from ..utils import to_2tuple
|
||||
from colossalai.context import seed
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
|
@ -42,6 +43,58 @@ class DropPath(nn.Module):
|
|||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class WrappedDropout(nn.Module):
|
||||
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager.
|
||||
"""
|
||||
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
|
||||
super().__init__()
|
||||
if p < 0 or p > 1:
|
||||
raise ValueError("dropout probability has to be between 0 and 1, "
|
||||
"but got {}".format(p))
|
||||
self.p = p
|
||||
self.inplace = inplace
|
||||
if mode is None:
|
||||
self.func = self.nonefunc
|
||||
else:
|
||||
self.func = self.normalfunc
|
||||
self.mode = mode
|
||||
|
||||
def nonefunc(self, inputs):
|
||||
return F.dropout(inputs, self.p, self.training, self.inplace)
|
||||
|
||||
def normalfunc(self, inputs):
|
||||
with seed(self.mode):
|
||||
return F.dropout(inputs, self.p, self.training, self.inplace)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.func(inputs)
|
||||
|
||||
|
||||
class WrappedDropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
Here, it is wrapped with the context of seed manager.
|
||||
"""
|
||||
def __init__(self, p: float = 0., mode=None):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
self.mode = mode
|
||||
if self.mode is None:
|
||||
self.func = self.nonefunc
|
||||
else:
|
||||
self.func = self.normalfunc
|
||||
self.mode = mode
|
||||
|
||||
def nonefunc(self, inputs):
|
||||
return drop_path(inputs, self.p, self.training)
|
||||
|
||||
def normalfunc(self, inputs):
|
||||
with seed(self.mode):
|
||||
return drop_path(inputs, self.p, self.training)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.func(inputs)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaPatchEmbedding(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
|
|
@ -6,6 +6,7 @@ from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
|||
from .loss_2d import CrossEntropyLoss2D
|
||||
from .loss_2p5d import CrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D
|
||||
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
|
||||
|
||||
_parallel_cross_entropy = {
|
||||
'2d': CrossEntropyLoss2D,
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import torch.nn as nn
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from colossalai.global_variables import moe_env
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MoeCrossEntropyLoss(_Loss):
|
||||
"""torch.nn.CrossEntropyLoss added with auxiliary loss.
|
||||
"""
|
||||
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args):
|
||||
main_loss = self.loss(*args)
|
||||
aux_loss = moe_env.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MoeLoss(_Loss):
|
||||
"""A wrapper class for any loss module to add with auxiliary loss.
|
||||
"""
|
||||
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_fn = loss_fn(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
main_loss = self.loss_fn(*args, **kwargs)
|
||||
aux_loss = moe_env.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
|
@ -66,7 +66,7 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
|
|||
:type last_epoch: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1):
|
||||
base_scheduler = _CosineAnnealingLR(
|
||||
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
|
||||
super().__init__(optimizer, warmup_steps, base_scheduler)
|
||||
|
|
|
@ -17,6 +17,7 @@ import torch.distributed as dist
|
|||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import moe_env
|
||||
|
||||
from .multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
|
@ -91,6 +92,10 @@ def is_model_parallel_parameter(p):
|
|||
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
|
||||
|
||||
|
||||
def is_moe_parallel_parameter(p):
|
||||
return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
|
||||
|
||||
|
||||
def _calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
|
@ -165,26 +170,37 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|||
else:
|
||||
tensor_parallel_grads = []
|
||||
no_tensor_parallel_grads = []
|
||||
moe_parallel_grads = [] # used to collect moe tensor parallel gradients
|
||||
for p in params:
|
||||
if is_model_parallel_parameter(p):
|
||||
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
|
||||
tensor_parallel_grads.append(p.grad.data / reductor)
|
||||
elif is_moe_parallel_parameter(p):
|
||||
moe_parallel_grads.append(p.grad.data)
|
||||
else:
|
||||
no_tensor_parallel_grads.append(p.grad.data)
|
||||
|
||||
if norm_type == 2.0:
|
||||
tensor_parallel_norm = _calc_l2_norm(
|
||||
tensor_parallel_grads) ** norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(
|
||||
no_tensor_parallel_grads) ** norm_type
|
||||
moe_parallel_norm = _calc_l2_norm(
|
||||
moe_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||
no_tensor_parallel_grads = _calc_lp(
|
||||
no_tensor_parallel_norm = _calc_lp(
|
||||
no_tensor_parallel_grads, norm_type)
|
||||
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
|
||||
dist.all_reduce(tensor_parallel_norm,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
# Sum across all moe-tensor-parallel GPUs
|
||||
if len(moe_parallel_grads) > 0:
|
||||
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
|
||||
no_tensor_parallel_norm += moe_parallel_norm
|
||||
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
dist.all_reduce(total_norm,
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.nn.layer import WrappedDropPath as DropPath
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
"""Transformer layer builder.
|
||||
"""
|
||||
def __init__(self,
|
||||
att: nn.Module,
|
||||
ffn: nn.Module,
|
||||
norm1: nn.Module,
|
||||
norm2: nn.Module,
|
||||
droppath=None,
|
||||
droppath_rate: float = 0):
|
||||
super().__init__()
|
||||
self.att = att
|
||||
self.ffn = ffn
|
||||
self.norm1 = norm1
|
||||
self.norm2 = norm2
|
||||
self.droppath = DropPath(droppath_rate) if droppath is None else droppath
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.droppath(self.att(self.norm1(x)))
|
||||
x = x + self.droppath(self.ffn(self.norm2(x)))
|
||||
return x
|
|
@ -0,0 +1,146 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
|
||||
WrappedDropout as Dropout, WrappedDropPath as DropPath
|
||||
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
|
||||
from .util import moe_sa_args, moe_mlp_args
|
||||
from ..helper import TransformerLayer
|
||||
from colossalai.global_variables import moe_env
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class VanillaSelfAttention(nn.Module):
|
||||
"""Standard ViT self attention.
|
||||
"""
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
d_kv: int,
|
||||
attention_drop: float = 0,
|
||||
drop_rate: float = 0,
|
||||
bias: bool = True,
|
||||
dropout1=None,
|
||||
dropout2=None):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.d_kv = d_kv
|
||||
self.scale = 1.0 / math.sqrt(self.d_kv)
|
||||
|
||||
self.dense1 = nn.Linear(d_model, 3 * n_heads * d_kv, bias, device=get_current_device())
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.atten_drop = nn.Dropout(attention_drop) if dropout1 is None else dropout1
|
||||
self.dense2 = nn.Linear(n_heads * d_kv, d_model, device=get_current_device())
|
||||
self.dropout = nn.Dropout(drop_rate) if dropout2 is None else dropout2
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.dense1(x)
|
||||
new_shape = qkv.shape[:2] + (3, self.n_heads, self.d_kv)
|
||||
qkv = qkv.view(*new_shape)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[:]
|
||||
|
||||
x = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
x = self.atten_drop(self.softmax(x))
|
||||
|
||||
x = torch.matmul(x, v)
|
||||
x = x.transpose(1, 2)
|
||||
new_shape = x.shape[:2] + (self.n_heads * self.d_kv,)
|
||||
x = x.reshape(*new_shape)
|
||||
x = self.dense2(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VanillaFFN(nn.Module):
|
||||
"""FFN composed with two linear layers, also called MLP.
|
||||
"""
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
d_ff: int,
|
||||
activation=None,
|
||||
drop_rate: float = 0,
|
||||
bias: bool = True,
|
||||
dropout1=None,
|
||||
dropout2=None):
|
||||
super().__init__()
|
||||
dense1 = nn.Linear(d_model, d_ff, bias, device=get_current_device())
|
||||
act = nn.GELU() if activation is None else activation
|
||||
dense2 = nn.Linear(d_ff, d_model, bias, device=get_current_device())
|
||||
drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1
|
||||
drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2
|
||||
|
||||
self.ffn = nn.Sequential(
|
||||
dense1, act, drop1,dense2, drop2)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ffn(x)
|
||||
|
||||
|
||||
class Widenet(nn.Module):
|
||||
def __init__(self,
|
||||
num_experts: int,
|
||||
capacity_factor: float,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
depth: int = 12,
|
||||
d_model: int = 768,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 3072,
|
||||
attention_drop: float = 0.,
|
||||
drop_rate: float = 0.1,
|
||||
drop_path: float = 0.):
|
||||
super().__init__()
|
||||
|
||||
embedding = VanillaPatchEmbedding(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_size=d_model)
|
||||
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
|
||||
|
||||
shared_sa = VanillaSelfAttention(**moe_sa_args(
|
||||
d_model=d_model, n_heads=num_heads, d_kv=d_kv,
|
||||
attention_drop=attention_drop, drop_rate=drop_rate))
|
||||
|
||||
noisy_func = NormalNoiseGenerator(num_experts)
|
||||
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
|
||||
shared_experts = Experts(expert=VanillaFFN,
|
||||
num_experts=num_experts,
|
||||
**moe_mlp_args(
|
||||
d_model=d_model,
|
||||
d_ff=d_ff,
|
||||
drop_rate=drop_rate
|
||||
))
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
|
||||
blocks = [
|
||||
TransformerLayer(
|
||||
att=shared_sa,
|
||||
ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
|
||||
router=shared_router, experts=shared_experts),
|
||||
norm1=nn.LayerNorm(d_model, eps=1e-6),
|
||||
norm2=nn.LayerNorm(d_model, eps=1e-6),
|
||||
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
self.linear = VanillaClassifier(in_features=d_model,
|
||||
num_classes=num_classes)
|
||||
nn.init.zeros_(self.linear.weight)
|
||||
nn.init.zeros_(self.linear.bias)
|
||||
self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
|
||||
|
||||
def forward(self, x):
|
||||
moe_env.reset_loss()
|
||||
x = self.widenet(x)
|
||||
x = torch.mean(x, dim=1)
|
||||
x = self.linear(x)
|
||||
return x
|
|
@ -0,0 +1,41 @@
|
|||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer import WrappedDropout as Dropout
|
||||
|
||||
|
||||
def moe_sa_args(d_model: int,
|
||||
n_heads: int,
|
||||
d_kv: int,
|
||||
attention_drop: float = 0,
|
||||
drop_rate: float = 0,
|
||||
bias: bool = True):
|
||||
"""This is an example for args in moe self attention, since lots of modules should be
|
||||
adapted before putting them in experts.
|
||||
"""
|
||||
dropout1 = Dropout(attention_drop, mode=ParallelMode.TENSOR)
|
||||
dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
|
||||
return dict(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
d_kv=d_kv,
|
||||
bias=bias,
|
||||
dropout1=dropout1,
|
||||
dropout2=dropout2
|
||||
)
|
||||
|
||||
|
||||
def moe_mlp_args(d_model: int,
|
||||
d_ff: int,
|
||||
drop_rate: float,
|
||||
bias: bool = True):
|
||||
"""This is an example for args of MLP in Experts, since lots of modules should be adapted
|
||||
before putting them in experts.
|
||||
"""
|
||||
dropout1 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
|
||||
dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
|
||||
return dict(
|
||||
d_model=d_model,
|
||||
d_ff=d_ff,
|
||||
bias=bias,
|
||||
dropout1=dropout1,
|
||||
dropout2=dropout2
|
||||
)
|
Loading…
Reference in New Issue