diff --git a/colossalai/constants.py b/colossalai/constants.py index 2ba535f43..0fb8ed77e 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -15,7 +15,8 @@ INITIALIZER_MAPPING = { '2.5d': 'Initializer_2p5D', '3d': 'Initializer_3D', 'sequence': 'Initializer_Sequence', - 'model': 'Initializer_Model' + 'model': 'Initializer_Model', + 'moe': 'Initializer_Moe' } # 1D parallel diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 5bad70f00..e18ea6845 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -15,6 +15,7 @@ from colossalai.registry import DIST_GROUP_INITIALIZER from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode +from colossalai.global_variables import moe_env class ParallelContext: @@ -412,6 +413,13 @@ class ParallelContext: # add this config to initialize later pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) + # initialization for moe environment + if parallel_config is not None and 'moe' in parallel_config: + param = parallel_config['moe'] + assert 'size' in param, "Moe model parallel size should be given" + moe_env.setup(param['size']) + pg_init.append(dict(type=INITIALIZER_MAPPING['moe'])) + # run initialization of different process groups for initializer_cfg in pg_init: cfg = initializer_cfg.copy() diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py index 440526eae..d50448513 100644 --- a/colossalai/context/parallel_mode.py +++ b/colossalai/context/parallel_mode.py @@ -44,3 +44,7 @@ class ParallelMode(Enum): PARALLEL_2P5D_COL = '2p5d_col' PARALLEL_2P5D_DEP = '2p5d_dep' PARALLEL_2P5D_XZ = '2p5d_xz' + + # MOE parallel + MOE_DATA = 'moe_data' + MOE_MODEL = 'moe_model' diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py index b98b64310..e8262162b 100644 --- a/colossalai/context/process_group_initializer/__init__.py +++ b/colossalai/context/process_group_initializer/__init__.py @@ -7,10 +7,12 @@ from .initializer_pipeline import Initializer_Pipeline from .initializer_sequence import Initializer_Sequence from .initializer_tensor import Initializer_Tensor from .initializer_model import Initializer_Model +from .initializer_moe import Initializer_Moe from .process_group_initializer import ProcessGroupInitializer __all__ = [ 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D', - 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' + 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model', + 'Initializer_Moe' ] diff --git a/colossalai/context/process_group_initializer/initializer_moe.py b/colossalai/context/process_group_initializer/initializer_moe.py new file mode 100644 index 000000000..5ab880ec1 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_moe.py @@ -0,0 +1,97 @@ +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.global_variables import moe_env +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Moemodel(ProcessGroupInitializer): + """Model parallel initialization for MoE system. 
+ """ + + def __init__(self, moe_model, moe_data, *args, **kwargs): + super().__init__(*args, **kwargs) + self.moe_model = moe_model + self.moe_data = moe_data + + def init_dist_group(self): + """Initialize model parallel groups in moe parallel environment, + and assign local_ranks and groups to each gpu. + """ + + local_rank = None + ranks_in_group = None + process_group = None + group_world_size = None + mode = ParallelMode.MOE_MODEL + + for i in range(self.moe_data): + ranks = [i * self.moe_model + j for j in range(self.moe_model)] + group = dist.new_group(ranks) + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Moedata(ProcessGroupInitializer): + """Data parallel initialization for MoE system. + """ + + def __init__(self, moe_model, moe_data, *args, **kwargs): + super().__init__(*args, **kwargs) + self.moe_model = moe_model + self.moe_data = moe_data + + def init_dist_group(self): + """Initialize data parallel groups in moe parallel environment, + and assign local_ranks and groups to each gpu. + """ + + local_rank = None + ranks_in_group = None + process_group = None + group_world_size = None + mode = ParallelMode.MOE_DATA + + for i in range(self.moe_model): + ranks = [i + j * self.moe_model for j in range(self.moe_data)] + group = dist.new_group(ranks) + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Moe(ProcessGroupInitializer): + """Serves as the single entry point to MoE parallel initialization. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.moe_model = moe_env.model_parallel_size + self.moe_data = moe_env.data_parallel_size + self.model_initializer = Initializer_Moemodel( + self.moe_model, self.moe_data, *args, **kwargs) + self.data_initializer = Initializer_Moedata( + self.moe_model, self.moe_data, *args, **kwargs) + + def init_dist_group(self): + """Initializes MoE parallel communication groups. 
+        """
+
+        parallel_setting = [self.model_initializer.init_dist_group(),
+                            self.data_initializer.init_dist_group()]
+        return parallel_setting
diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py
index 29e77e3ec..675fea5aa 100644
--- a/colossalai/context/random/__init__.py
+++ b/colossalai/context/random/__init__.py
@@ -1,8 +1,9 @@
 from ._helper import (seed, set_mode, with_seed, add_seed,
                       get_seeds, get_states, get_current_mode,
-                      set_seed_states, sync_states)
+                      set_seed_states, sync_states, moe_set_seed)
 
 __all__ = [
     'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
-    'get_states', 'get_current_mode', 'set_seed_states', 'sync_states'
+    'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
+    'moe_set_seed'
 ]
diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py
index 1bc7af738..ba5308cdc 100644
--- a/colossalai/context/random/_helper.py
+++ b/colossalai/context/random/_helper.py
@@ -49,7 +49,7 @@ def get_current_mode():
     return _SEED_MANAGER.current_mode
 
 
-def add_seed(parallel_mode: ParallelMode, seed: int):
+def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
     """Adds a seed to the seed manager for `parallel_mode`.
 
     :param parallel_mode: The chosen parallel mode
@@ -59,7 +59,7 @@
     :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
         of :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
     """
-    _SEED_MANAGER.add_seed(parallel_mode, seed)
+    _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
 
 
 def set_mode(parallel_mode: ParallelMode):
@@ -142,3 +142,16 @@ def with_seed(func, parallel_mode: ParallelMode):
         return out
 
     return wrapper
+
+
+def moe_set_seed(seed):
+    if torch.cuda.is_available():
+        from colossalai.core import global_context as gpc
+        moe_mp_rank = gpc.get_local_rank(ParallelMode.MOE_MODEL)
+        moe_mp_seed = seed + moe_mp_rank
+        add_seed(ParallelMode.MOE_MODEL, moe_mp_seed)
+
+        global_rank = gpc.get_global_rank()
+        add_seed(ParallelMode.TENSOR, global_rank, True)
+        print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
+              f"tensor seed {global_rank}", flush=True)
diff --git a/colossalai/context/random/seed_manager.py b/colossalai/context/random/seed_manager.py
index 3e74c8cb9..33cdf25aa 100644
--- a/colossalai/context/random/seed_manager.py
+++ b/colossalai/context/random/seed_manager.py
@@ -54,7 +54,7 @@
         self._current_mode = parallel_mode
         torch.cuda.set_rng_state(self._seed_states[parallel_mode])
 
-    def add_seed(self, parallel_mode: ParallelMode, seed: int):
+    def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
         """Adds a seed to the seed manager for `parallel_mode`.
         :param parallel_mode: The chosen parallel mode
@@ -66,7 +66,11 @@
         """
         assert isinstance(
            parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
-        assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
+        if overwrite is False:
+            assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
+        elif parallel_mode in self._seed_states:
+            print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
+
         current_state = torch.cuda.get_rng_state()
         torch.cuda.manual_seed(seed)
         self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py
index b2fd2d442..863bb6b5b 100644
--- a/colossalai/engine/gradient_handler/__init__.py
+++ b/colossalai/engine/gradient_handler/__init__.py
@@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
+from ._moe_gradient_handler import MoeGradientHandler
 
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
-           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
+           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
+           'MoeGradientHandler']
diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
new file mode 100644
index 000000000..18456d9f7
--- /dev/null
+++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
@@ -0,0 +1,61 @@
+import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from colossalai.core import global_context as gpc
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.global_variables import moe_env
+from ._base_gradient_handler import BaseGradientHandler
+from ...context.parallel_mode import ParallelMode
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in the data parallel group and,
+    for expert parameters, in the MoE data parallel group. An all-reduce collective
+    communication is performed in :func:`handle_gradient` among the data parallel group.
+    For better performance, it bucketizes the gradients of all parameters that are
+    of the same type to improve the efficiency of communication.
+    """
+
+    def handle_gradient(self):
+        """Runs an all-reduce operation in the data parallel group, then runs
+        an all-reduce operation for all expert parameters
+        across the MoE data parallel group.
+        """
+        moe_data = moe_env.data_parallel_size
+        global_data = gpc.data_parallel_size
+
+        if global_data > 1:
+            # bucketize and all-reduce
+            buckets = {}
+            # Pack the buckets.
+            for param in self._model.parameters():
+                if param.requires_grad and \
+                   param.grad is not None and \
+                   not hasattr(param, 'moe_param'):
+                    tp = param.data.type()
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+                    # param.main_grad = param.grad
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                coalesced /= gpc.get_world_size(ParallelMode.DATA)
+
+                dist.all_reduce(
+                    coalesced, group=gpc.get_group(ParallelMode.DATA))
+                for buf, synced in zip(grads, _unflatten_dense_tensors(
+                        coalesced, grads)):
+                    buf.copy_(synced)
+
+        if global_data > 1:
+            for param in self._model.parameters():
+                if not param.requires_grad or param.grad is None:
+                    continue
+                if moe_data > 1 and hasattr(param, 'moe_param'):
+                    param.grad.data /= moe_data
+                    dist.all_reduce(param.grad.data,
+                                    group=gpc.get_group(ParallelMode.MOE_DATA))
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index 76c550144..90f1ac0a1 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
         return data
 
     @staticmethod
-    def _check_sanity(data, tag):
-        assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
+    def _check_sanity(data, tag: str):
+        assert isinstance(data, (torch.Tensor, dict)), \
+            f'{tag} must be torch.Tensor or dict'
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py
new file mode 100644
index 000000000..3483b6eb0
--- /dev/null
+++ b/colossalai/global_variables.py
@@ -0,0 +1,36 @@
+
+
+class MoeEnv:
+    """MoE environment variables.
+    """
+
+    def __init__(self):
+        self.data_parallel_size = None
+        self.model_parallel_size = None
+        self.aux_loss = None
+
+    def setup(self, moe_model_size):
+        from .core import global_context as gpc
+        if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
+            raise NotImplementedError("MoE is not compatible with tensor or pipeline parallelism")
+
+        assert gpc.data_parallel_size % moe_model_size == 0, \
+            "The data parallel size must be divisible by the MoE model parallel size"
+
+        self.data_parallel_size = gpc.data_parallel_size // moe_model_size
+        self.model_parallel_size = moe_model_size
+
+    def is_initialized(self):
+        return self.model_parallel_size is not None
+
+    def reset_loss(self):
+        self.aux_loss = 0
+
+    def add_loss(self, loss):
+        self.aux_loss += loss
+
+    def get_loss(self):
+        return self.aux_loss
+
+
+moe_env = MoeEnv()
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 39837e464..ae75ee996 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -5,7 +5,6 @@ import argparse
 import pprint
 import os
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-import numpy as np
 
 import torch
 import torch.nn as nn
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch.nn.modules.loss import _Loss
 from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.global_variables import moe_env
 
 
 def get_default_parser():
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                lr_scheduler: _LRScheduler = None,
                verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader]:
+               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     ''' Core function to wrap the essential training components with our functionality
     based on the config which is loaded into gpc.config.
     :param model: your model instance
@@ -269,8 +269,13 @@
     # first sync model across dp ranks
     model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not use_zero3:
+    if not moe_env.is_initialized() and not use_zero3:
         sync_model_param_in_dp(model)
+    else:
+        print(
+            "Warning: The parameters of the model are not automatically synchronized.\n"
+            "Please make sure that all parameters are the same across the data parallel group.",
+            flush=True)
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
@@ -327,6 +332,13 @@
                 "Training with zero is detected, ZeROGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
+    elif is_using_ddp() and moe_env.is_initialized():
+        gradient_handler_cfg = [dict(type='MoeGradientHandler')]
+        if verbose:
+            logger.info(
+                "Data parallel training with MoE parallel is detected, MoeGradientHandler is automatically "
+                "added even though not specified in the configuration",
+                ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
         if verbose:
diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
new file mode 100644
index 000000000..e75aff6ed
--- /dev/null
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -0,0 +1,8 @@
+from ._operation import AllToAll
+from .layers import Experts, MoeLayer, \
+    NormalNoiseGenerator, Top1Router, Top2Router
+
+__all__ = [
+    'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
+    'MoeLayer', 'NormalNoiseGenerator'
+]
\ No newline at end of file
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
new file mode 100644
index 000000000..fd2720fb9
--- /dev/null
+++ b/colossalai/nn/layer/moe/_operation.py
@@ -0,0 +1,29 @@
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from typing import Any, Tuple
+
+
+class AllToAll(torch.autograd.Function):
+    """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
+    operation in torch.distributed.
+    """
+    @staticmethod
+    def forward(ctx: Any,
+                inputs: Tensor,
+                parallel_mode: ParallelMode) -> Tensor:
+        ctx.parallel_mode = parallel_mode
+        if not inputs.is_contiguous():
+            inputs = inputs.contiguous()
+
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs,
+                               group=gpc.get_group(parallel_mode))
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
new file mode 100644
index 000000000..51b676567
--- /dev/null
+++ b/colossalai/nn/layer/moe/layers.py
@@ -0,0 +1,242 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+from colossalai.global_variables import moe_env
+from colossalai.context import ParallelMode, seed
+from colossalai.utils import get_current_device
+from ._operation import AllToAll
+
+
+class NormalNoiseGenerator:
+    """Generates a random noisy mask for logits tensor.
+    All noise is generated from a normal distribution (0, 1 / E^2), where
+    E = the number of experts.
+    """
+
+    def __init__(self, num_experts: int):
+        self.normal = torch.distributions.normal.Normal(
+            loc=torch.tensor(0.0, device=get_current_device()),
+            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
+        ).rsample
+
+    def __call__(self, inputs: torch.Tensor):
+        noisy = self.normal(inputs.shape)
+        return inputs + noisy
+
+
+class Experts(nn.Module):
+    """A wrapper class to create experts. It will create E experts across the
+    moe model parallel group, where E is the number of experts. Every expert
+    is an instance of the class 'expert' given in the initialization parameters.
+    """
+
+    def __init__(self, expert, num_experts, **expert_args):
+        super().__init__()
+
+        assert num_experts % moe_env.model_parallel_size == 0, \
+            "The number of experts should be divisible by the moe model parallel size"
+
+        num_local_experts = num_experts // moe_env.model_parallel_size
+        with seed(ParallelMode.MOE_MODEL):
+            self.experts = nn.ModuleList([
+                expert(**expert_args) for _ in range(num_local_experts)])
+        self.num_local_experts = num_local_experts
+        for exp in self.experts:
+            for param in exp.parameters():
+                param.__setattr__('moe_param', 1)
+
+    def forward(self, inputs):
+        expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
+        expert_output = []
+
+        for i in range(self.num_local_experts):
+            expert_output.append(self.experts[i](expert_input[i]))
+
+        output = torch.cat(expert_output, dim=0)
+        return output
+
+
+class Top1Router(nn.Module):
+    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
+    for routing usage. A more detailed description can be found in Google's paper
+    on the Switch Transformer.
+    """
+
+    def __init__(self,
+                 capacity_factor: float,
+                 min_capacity: int,
+                 noisy_func=None):
+        super().__init__()
+        self.capacity_factor = capacity_factor
+        self.min_capacity = min_capacity
+        self.noisy_func = noisy_func
+        self.uniform = torch.distributions.uniform.Uniform(
+            low=torch.tensor(0.0, device=get_current_device()),
+            high=torch.tensor(1.0, device=get_current_device())).rsample
+
+    def get_capacity(self, logits_shape):
+        capacity = math.ceil(self.capacity_factor *
+                             logits_shape[0] / logits_shape[1])
+        if capacity < self.min_capacity:
+            capacity = self.min_capacity
+        return capacity
+
+    def forward(self, inputs):
+
+        if self.noisy_func is not None:
+            inputs_noisy = self.noisy_func(inputs)
+        else:
+            inputs_noisy = inputs
+
+        logits = F.softmax(inputs, dim=1)
+
+        num_experts = logits.shape[1]
+        capacity = self.get_capacity(logits.shape)
+
+        expert_idx = torch.argmax(inputs_noisy, dim=1)
+        expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
+        expert_mask_f = expert_mask.float()
+
+        exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
+
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(expert_mask_f, dim=0)
+        l_aux = torch.sum(me * ce) * num_experts
+        moe_env.add_loss(l_aux)
+
+        rand_mask = expert_mask * self.uniform(logits.shape)
+        _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+
+        dispatch_mask = \
+            expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
+
+        locations = torch.cumsum(dispatch_mask, dim=0) - 1
+        locations = torch.sum(dispatch_mask * locations, dim=1)
+        locations = F.one_hot(locations, num_classes=capacity)
+
+        logits = logits * dispatch_mask
+        combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
+
+        sec_mask = combine_weights.bool()
+        return combine_weights, sec_mask, exp_counts
+
+
+class Top2Router(nn.Module):
+    """Top2 router that returns the dispatch mask [s, e, c] and combine
+    weight [s, e, c] for routing usage. A more detailed description is in the ViT-MoE paper.
+    """
+
+    def __init__(self, capacity_factor: float, noisy_func=None):
+        super().__init__()
+        self.capacity_factor = capacity_factor
+        self.noisy_func = noisy_func
+
+    def get_capacity(self, logits_shape):
+        capacity = math.ceil(2 * self.capacity_factor *
+                             logits_shape[0] / logits_shape[1])
+        return capacity
+
+    def forward(self, inputs):
+        if self.noisy_func is not None:
+            inputs = self.noisy_func(inputs)
+
+        logits = F.softmax(inputs, dim=-1)
+        num_experts = logits.size(-1)
+        capacity = self.get_capacity(logits.shape)
+
+        _, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
+        top1_idx = expert_idx[:, 0]
+        top2_idx = expert_idx[:, 1]
+
+        mask1 = F.one_hot(top1_idx, num_classes=num_experts)
+        mask2 = F.one_hot(top2_idx, num_classes=num_experts)
+
+        loss_mask = (mask1 + mask2)
+        exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(loss_mask.float(), dim=0)
+        l_aux = num_experts * torch.sum(me * ce) / 2.0
+        moe_env.add_loss(l_aux)
+
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+        locations2 += torch.sum(mask1, dim=0, keepdim=True)
+
+        mask1 *= torch.lt(locations1, capacity)
+        mask2 *= torch.lt(locations2, capacity)
+
+        weight1 = mask1 * logits
+        weight2 = mask2 * logits
+
+        locations1 = torch.sum(mask1 * locations1, dim=1)
+        locations2 = torch.sum(mask2 * locations2, dim=1)
+        locations1_sc = F.one_hot(locations1, num_classes=capacity)
+        locations2_sc = F.one_hot(locations2, num_classes=capacity)
+
+        combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
+        combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
+        combine_weights = combine_weights1 + combine_weights2
+        sec_mask = combine_weights.bool()
+
+        return combine_weights, sec_mask, exp_counts
+
+
+class MoeLayer(nn.Module):
+    """A MoE layer that feeds its input tensor to the gate and uses the output logits
+    to route all tokens. It exchanges tokens with every expert across the moe model
+    parallel group by all-to-all communication, gathers the expert outputs and
+    exchanges them back. Finally it returns the output of the MoE system.
+ """ + + def __init__(self, + dim_model: int, + num_experts: int, + router: nn.Module, + experts: nn.Module): + super().__init__() + self.d_model = dim_model + self.num_experts = num_experts + self.gate = nn.Linear(dim_model, num_experts, device=get_current_device()) + self.router = router + self.experts = experts + + def _router_part(self, tokens: torch.Tensor): + gate_output = self.gate(tokens) + return self.router(gate_output) + + def router_part(self, tokens: torch.Tensor): + autocast_context = torch.is_autocast_enabled() + if not autocast_context: + return self._router_part(tokens) + else: + with autocast(enabled=False): + if tokens.dtype == torch.float16: + input_tokens = tokens.float() + else: + input_tokens = tokens + return self._router_part(input_tokens) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + tokens = inputs.reshape(-1, self.d_model) + + combine_weights, sec_mask, exp_counts = self.router_part(tokens) + + sec_mask_f = sec_mask.type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL) + + expert_output = self.experts(dispatch_data) + + expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL) + + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + + ret = torch.matmul(combine_weights, expert_output) + ret = ret.reshape(inputs.shape) + + return ret diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py index 962c8e540..14af80027 100644 --- a/colossalai/nn/layer/vanilla/__init__.py +++ b/colossalai/nn/layer/vanilla/__init__.py @@ -1,3 +1,5 @@ -from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding +from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \ + WrappedDropout, WrappedDropPath -__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath'] +__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath', + 'WrappedDropout', 'WrappedDropPath'] diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py index a89e5e1e9..dc33c461e 100644 --- a/colossalai/nn/layer/vanilla/layers.py +++ b/colossalai/nn/layer/vanilla/layers.py @@ -10,6 +10,7 @@ from torch import Tensor, dtype from torch import nn as nn from ..utils import to_2tuple +from colossalai.context import seed def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -42,6 +43,58 @@ class DropPath(nn.Module): return drop_path(x, self.drop_prob, self.training) +class WrappedDropout(nn.Module): + """Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. + """ + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): + super().__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " + "but got {}".format(p)) + self.p = p + self.inplace = inplace + if mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def normalfunc(self, inputs): + with seed(self.mode): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def forward(self, inputs): + return self.func(inputs) + + +class WrappedDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Here, it is wrapped with the context of seed manager. 
+ """ + def __init__(self, p: float = 0., mode=None): + super().__init__() + self.p = p + self.mode = mode + if self.mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return drop_path(inputs, self.p, self.training) + + def normalfunc(self, inputs): + with seed(self.mode): + return drop_path(inputs, self.p, self.training) + + def forward(self, inputs): + return self.func(inputs) + + @LAYERS.register_module class VanillaPatchEmbedding(nn.Module): """ 2D Image to Patch Embedding diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 65eef4a9e..87f43bb84 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -6,6 +6,7 @@ from colossalai.nn.layer.utils import get_tensor_parallel_mode from .loss_2d import CrossEntropyLoss2D from .loss_2p5d import CrossEntropyLoss2p5D from .loss_3d import CrossEntropyLoss3D +from .loss_moe import MoeCrossEntropyLoss, MoeLoss _parallel_cross_entropy = { '2d': CrossEntropyLoss2D, diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py new file mode 100644 index 000000000..ebbc0e4c8 --- /dev/null +++ b/colossalai/nn/loss/loss_moe.py @@ -0,0 +1,34 @@ +import torch.nn as nn +from colossalai.registry import LOSSES +from torch.nn.modules.loss import _Loss +from colossalai.global_variables import moe_env + + +@LOSSES.register_module +class MoeCrossEntropyLoss(_Loss): + """torch.nn.CrossEntropyLoss added with auxiliary loss. + """ + def __init__(self, aux_weight: float = 0.01, *args, **kwargs): + super().__init__() + self.loss = nn.CrossEntropyLoss(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args): + main_loss = self.loss(*args) + aux_loss = moe_env.get_loss() + return main_loss + self.aux_weight * aux_loss + + +@LOSSES.register_module +class MoeLoss(_Loss): + """A wrapper class for any loss module to add with auxiliary loss. 
+ """ + def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): + super().__init__() + self.loss_fn = loss_fn(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args, **kwargs): + main_loss = self.loss_fn(*args, **kwargs) + aux_loss = moe_env.get_loss() + return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index 0df30baab..d71b2a6d6 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -66,7 +66,7 @@ class CosineAnnealingWarmupLR(WarmupScheduler): :type last_epoch: int, optional """ - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1): + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): base_scheduler = _CosineAnnealingLR( optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch) super().__init__(optimizer, warmup_steps, base_scheduler) diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 3d64a7b6f..a93818b07 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -17,6 +17,7 @@ import torch.distributed as dist from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.global_variables import moe_env from .multi_tensor_apply import multi_tensor_applier @@ -91,6 +92,10 @@ def is_model_parallel_parameter(p): return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) +def is_moe_parallel_parameter(p): + return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1 + + def _calc_l2_norm(grads): norm = 0.0 if len(grads) > 0: @@ -165,26 +170,37 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): else: tensor_parallel_grads = [] no_tensor_parallel_grads = [] + moe_parallel_grads = [] # used to collect moe tensor parallel gradients for p in params: if is_model_parallel_parameter(p): reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) tensor_parallel_grads.append(p.grad.data / reductor) + elif is_moe_parallel_parameter(p): + moe_parallel_grads.append(p.grad.data) else: no_tensor_parallel_grads.append(p.grad.data) + if norm_type == 2.0: tensor_parallel_norm = _calc_l2_norm( tensor_parallel_grads) ** norm_type no_tensor_parallel_norm = _calc_l2_norm( no_tensor_parallel_grads) ** norm_type + moe_parallel_norm = _calc_l2_norm( + moe_parallel_grads) ** norm_type else: tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) - no_tensor_parallel_grads = _calc_lp( + no_tensor_parallel_norm = _calc_lp( no_tensor_parallel_grads, norm_type) + moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type) # Sum across all model-parallel GPUs. 
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + # Sum across all moe-tensor-parallel GPUs + if len(moe_parallel_grads) > 0: + dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL)) + no_tensor_parallel_norm += moe_parallel_norm total_norm = tensor_parallel_norm + no_tensor_parallel_norm if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: dist.all_reduce(total_norm, diff --git a/model_zoo/helper.py b/model_zoo/helper.py new file mode 100644 index 000000000..0f4fac17c --- /dev/null +++ b/model_zoo/helper.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn +from colossalai.nn.layer import WrappedDropPath as DropPath + + +class TransformerLayer(nn.Module): + """Transformer layer builder. + """ + def __init__(self, + att: nn.Module, + ffn: nn.Module, + norm1: nn.Module, + norm2: nn.Module, + droppath=None, + droppath_rate: float = 0): + super().__init__() + self.att = att + self.ffn = ffn + self.norm1 = norm1 + self.norm2 = norm2 + self.droppath = DropPath(droppath_rate) if droppath is None else droppath + + def forward(self, x): + x = x + self.droppath(self.att(self.norm1(x))) + x = x + self.droppath(self.ffn(self.norm2(x))) + return x diff --git a/model_zoo/moe/__init__.py b/model_zoo/moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py new file mode 100644 index 000000000..eb1db9caa --- /dev/null +++ b/model_zoo/moe/models.py @@ -0,0 +1,146 @@ +import math +import torch +import torch.nn as nn +from colossalai.context import ParallelMode +from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \ + WrappedDropout as Dropout, WrappedDropPath as DropPath +from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator +from .util import moe_sa_args, moe_mlp_args +from ..helper import TransformerLayer +from colossalai.global_variables import moe_env +from colossalai.utils import get_current_device + + +class VanillaSelfAttention(nn.Module): + """Standard ViT self attention. + """ + def __init__(self, + d_model: int, + n_heads: int, + d_kv: int, + attention_drop: float = 0, + drop_rate: float = 0, + bias: bool = True, + dropout1=None, + dropout2=None): + super().__init__() + self.n_heads = n_heads + self.d_kv = d_kv + self.scale = 1.0 / math.sqrt(self.d_kv) + + self.dense1 = nn.Linear(d_model, 3 * n_heads * d_kv, bias, device=get_current_device()) + self.softmax = nn.Softmax(dim=-1) + self.atten_drop = nn.Dropout(attention_drop) if dropout1 is None else dropout1 + self.dense2 = nn.Linear(n_heads * d_kv, d_model, device=get_current_device()) + self.dropout = nn.Dropout(drop_rate) if dropout2 is None else dropout2 + + def forward(self, x): + qkv = self.dense1(x) + new_shape = qkv.shape[:2] + (3, self.n_heads, self.d_kv) + qkv = qkv.view(*new_shape) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[:] + + x = torch.matmul(q, k.transpose(-2, -1)) * self.scale + x = self.atten_drop(self.softmax(x)) + + x = torch.matmul(x, v) + x = x.transpose(1, 2) + new_shape = x.shape[:2] + (self.n_heads * self.d_kv,) + x = x.reshape(*new_shape) + x = self.dense2(x) + x = self.dropout(x) + + return x + + +class VanillaFFN(nn.Module): + """FFN composed with two linear layers, also called MLP. 
+ """ + def __init__(self, + d_model: int, + d_ff: int, + activation=None, + drop_rate: float = 0, + bias: bool = True, + dropout1=None, + dropout2=None): + super().__init__() + dense1 = nn.Linear(d_model, d_ff, bias, device=get_current_device()) + act = nn.GELU() if activation is None else activation + dense2 = nn.Linear(d_ff, d_model, bias, device=get_current_device()) + drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1 + drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2 + + self.ffn = nn.Sequential( + dense1, act, drop1,dense2, drop2) + + def forward(self, x): + return self.ffn(x) + + +class Widenet(nn.Module): + def __init__(self, + num_experts: int, + capacity_factor: float, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + depth: int = 12, + d_model: int = 768, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 3072, + attention_drop: float = 0., + drop_rate: float = 0.1, + drop_path: float = 0.): + super().__init__() + + embedding = VanillaPatchEmbedding( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_size=d_model) + embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) + + shared_sa = VanillaSelfAttention(**moe_sa_args( + d_model=d_model, n_heads=num_heads, d_kv=d_kv, + attention_drop=attention_drop, drop_rate=drop_rate)) + + noisy_func = NormalNoiseGenerator(num_experts) + shared_router = Top2Router(capacity_factor, noisy_func=noisy_func) + shared_experts = Experts(expert=VanillaFFN, + num_experts=num_experts, + **moe_mlp_args( + d_model=d_model, + d_ff=d_ff, + drop_rate=drop_rate + )) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] + blocks = [ + TransformerLayer( + att=shared_sa, + ffn=MoeLayer(dim_model=d_model, num_experts=num_experts, + router=shared_router, experts=shared_experts), + norm1=nn.LayerNorm(d_model, eps=1e-6), + norm2=nn.LayerNorm(d_model, eps=1e-6), + droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR) + ) + for i in range(depth) + ] + norm = nn.LayerNorm(d_model, eps=1e-6) + self.linear = VanillaClassifier(in_features=d_model, + num_classes=num_classes) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm) + + def forward(self, x): + moe_env.reset_loss() + x = self.widenet(x) + x = torch.mean(x, dim=1) + x = self.linear(x) + return x diff --git a/model_zoo/moe/util.py b/model_zoo/moe/util.py new file mode 100644 index 000000000..60028656e --- /dev/null +++ b/model_zoo/moe/util.py @@ -0,0 +1,41 @@ +from colossalai.context import ParallelMode +from colossalai.nn.layer import WrappedDropout as Dropout + + +def moe_sa_args(d_model: int, + n_heads: int, + d_kv: int, + attention_drop: float = 0, + drop_rate: float = 0, + bias: bool = True): + """This is an example for args in moe self attention, since lots of modules should be + adapted before putting them in experts. + """ + dropout1 = Dropout(attention_drop, mode=ParallelMode.TENSOR) + dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR) + return dict( + d_model=d_model, + n_heads=n_heads, + d_kv=d_kv, + bias=bias, + dropout1=dropout1, + dropout2=dropout2 + ) + + +def moe_mlp_args(d_model: int, + d_ff: int, + drop_rate: float, + bias: bool = True): + """This is an example for args of MLP in Experts, since lots of modules should be adapted + before putting them in experts. 
+ """ + dropout1 = Dropout(drop_rate, mode=ParallelMode.TENSOR) + dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR) + return dict( + d_model=d_model, + d_ff=d_ff, + bias=bias, + dropout1=dropout1, + dropout2=dropout2 + )
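
The only configuration hook this patch reads when building process groups is parallel['moe']['size'] in ParallelContext.init_parallel_groups, and moe_env.setup() additionally requires tensor and pipeline parallelism to be off and the data parallel size to be divisible by that value. A minimal config sketch under those assumptions (the surrounding keys only follow the usual ColossalAI parallel-config shape and are illustrative, not introduced by this patch):

# Assumed minimal config enabling the MoE groups added above.
# Only parallel['moe']['size'] is read by this patch; it must divide the data parallel size.
parallel = dict(
    pipeline=1,                       # MoE is rejected when pipeline parallel size > 1
    tensor=dict(size=1, mode=None),   # ... or when tensor parallel size > 1
    moe=dict(size=4),                 # MoE model parallel size: experts are sharded across these ranks
)

When such a config is active and the run uses DDP-style data parallelism, initialize() now registers MoeGradientHandler automatically, so no explicit gradient_handler entry is needed for the common case.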
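
For reference, the rank layout that Initializer_Moe builds can be reproduced offline. The sketch below only mirrors the loops in Initializer_Moemodel and Initializer_Moedata; the world size of 8 and MoE size of 4 are example values, and it assumes no tensor or pipeline parallelism so that the data parallel size equals the world size:

# Standalone sketch of the MOE_MODEL / MOE_DATA group layout (example sizes only).
world_size = 8                        # example: 8 GPUs, pure data parallel
moe_model = 4                         # parallel['moe']['size'] in the example above
moe_data = world_size // moe_model    # what moe_env.setup() computes as data_parallel_size

# MOE_MODEL groups: consecutive ranks; experts are sharded and tokens exchanged here.
model_groups = [[i * moe_model + j for j in range(moe_model)] for i in range(moe_data)]
# MOE_DATA groups: strided ranks; expert gradients are averaged here.
data_groups = [[i + j * moe_model for j in range(moe_data)] for i in range(moe_model)]

print(model_groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(data_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]]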
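
A sketch of how the new building blocks compose into one MoE feed-forward block, in the spirit of the Widenet model above. It assumes colossalai.launch has already been called with a moe entry in the parallel config, so that moe_env, the MOE_MODEL/MOE_DATA groups and the MoE seed (moe_set_seed) all exist; ExampleFFN and all sizes are illustrative placeholders, not part of the patch:

import torch.nn as nn
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.loss import MoeLoss

d_model, d_ff, num_experts = 768, 3072, 8      # illustrative sizes


class ExampleFFN(nn.Module):
    """Hypothetical expert module; any nn.Module constructor works with Experts."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)


noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor=1.25, noisy_func=noisy_func)
experts = Experts(expert=ExampleFFN, num_experts=num_experts)   # one instance per local expert
moe_ffn = MoeLayer(dim_model=d_model, num_experts=num_experts, router=router, experts=experts)

# MoeLoss wraps any criterion class and adds aux_weight * moe_env.get_loss();
# the model is expected to call moe_env.reset_loss() at the start of forward, as Widenet does.
criterion = MoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)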