[moe] merge moe into main (#4978)

* update moe module
* support openmoe
Xuanlei Zhao 2023-11-02 10:21:24 +08:00 committed by GitHub
parent 8993c8a817
commit dc003c304c
67 changed files with 7618 additions and 1657 deletions

@@ -0,0 +1,382 @@
import random
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
get_param_info,
init_pipeline_optimizer,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoeCheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Plugin for Moe Hybrid Parallel Training.
Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp are passed in by the user, and the size of dp is then computed as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
assert tp_size > 1, "Tensor parallelism (tp_size > 1) must be enabled when using sequence parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
# sync moe params in the outer dp group, and sync other params in the global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
else:
self.moe_extra_dp_group = None
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
batch_size (int): The batch size on each data parallel rank.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> MoeCheckpintIO:
self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
assert self.dp_size > 1, "ZeRO requires the data parallel size to be greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_extra_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config,
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
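As a rough usage sketch (not part of this commit's code), the plugin drops into the standard Booster workflow. `build_model_and_data` is a hypothetical helper standing in for any model that contains MoE layers (with the MoE parallel context already initialized inside it), and the launch call assumes the usual `colossalai.launch_from_torch` entry point:

import colossalai
from colossalai.booster import Booster

colossalai.launch_from_torch(config={})
# hypothetical helper returning a MoE model, its optimizer/criterion and a dataset
model, optimizer, criterion, train_dataset = build_model_and_data()

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    extra_dp_size=2,   # with use_ep_inside=True, dp is split into an outer dp group of this size and an inner ep group
    zero_stage=1,
    precision="bf16",
)  # e.g. launched on 4 GPUs: dp_size = 4, ep_size = 2
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)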

@@ -1,7 +1,5 @@
from .config import Config, ConfigException
# from .moe_context import MOE_CONTEXT
__all__ = [
"Config",
"ConfigException",

@@ -1,132 +0,0 @@
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
self.ep_size = ep_size
self.dp_size = dp_size
self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
self.ep_group = self.pg.tp_process_group()
self.dp_group = self.pg.dp_process_group()
class MoeContext(metaclass=SingletonMeta):
"""MoE parallel context manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.world_size = 1
# Users may want to set maximum expert parallel size smaller than the world size
# since very low bandwidth across nodes may constrain the performance of MoE
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = 1
self.min_dp_size = 1
self.aux_loss = None
self.use_kernel_optim = True
self.has_setup = False
self._parallel_info_dict = dict()
@property
def parallel_info_dict(self):
return self._parallel_info_dict
@property
def is_initialized(self):
return self.has_setup
def setup(self, seed: int, use_kernel_optim: bool = True):
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
_check_sanity()
assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"
self.world_size = dist.get_world_size()
from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
assert (
self.world_size % self.max_ep_size == 0
), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise errors in some cases;
# users can disable kernel optimization manually
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
moe_set_seed(seed)
self.has_setup = True
def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
"""Calculate the Data Parallel Group and Expert Parallel Group.
Parameters
----------
num_experts : int
The number of experts
Returns
-------
int, MoeParallelInfo
number of local experts, the MoeParallelInfo of the current ep_size
"""
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, (
"Automatic expert placement requires the number of experts to be "
"a multiple of the expert parallel size, or vice versa."
)
# If the number of experts is greater than the maximum expert parallel size (a.k.a. ep_size),
# each GPU holds several distinct experts,
# so the data parallel size is 1.
# Otherwise, there is only one expert on each GPU
# and the data parallel size has to be calculated.
dp_size = 1 if gt_flag else self.max_ep_size // num_experts
ep_size = self.max_ep_size // dp_size
# Calculate the number of experts for each GPU
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
# Don't forget to multiply minimum data parallel size
dp_size *= self.min_dp_size
if not (ep_size in self.parallel_info_dict):
self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
return num_local_experts, self.parallel_info_dict[ep_size]
def set_kernel_not_use(self):
self.use_kernel_optim = False
def reset_loss(self):
self.aux_loss = 0
def add_loss(self, loss):
self.aux_loss += loss
def get_loss(self):
return self.aux_loss
MOE_CONTEXT = MoeContext()

@@ -0,0 +1,185 @@
from functools import reduce
from typing import Any, Tuple
import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
PRECISION_MAP = {
"fp32": (0, torch.float32),
"fp16": (1, torch.float16),
"bf16": (2, torch.bfloat16),
}
@triton.jit
def _llama_act_combine_forward(
X_GATE1,
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
Y += row * stride
# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
tl.store(Y + cols, y, mask=mask)
@triton.jit
def _llama_act_combine_backward(
X_GATE1,
X_GATE2,
X_UP,
X_GATE1_GRAD,
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
X_GATE1_GRAD += row * stride
X_GATE2_GRAD += row * stride
X_UP_GRAD += row * stride
Y_GRAD += row * stride
# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
x_up_grad = x_gate2_act * x_gate1
x_gate1_grad = x_gate2_act * x_up
# grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
#                    = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
# Write output
tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
class LlamaActCombine(torch.autograd.Function):
"""
act(x_gate) * x_up
Args:
x_gate (torch.Tensor): (b, l, 2d) x_gate
x_up (torch.Tensor): (b, l, d) x_up
activation (str): only support swiglu
precision (str): fp32, fp16, bf16
"""
@staticmethod
@custom_fwd
def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
"""
act(x_gate) * x_up
Args:
x_gate (torch.Tensor): (b, l, 2d) x gate
x_up (torch.Tensor): (b, l, d) x up
activation (str): only support swiglu
"""
assert activation == "swiglu", "Only swiglu is supported"
# split x gate
assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
x_gate1 = x_gate1.contiguous()
x_gate2 = x_gate2.contiguous()
if not x_up.is_contiguous():
x_up = x_up.contiguous()
# assert shape
assert x_gate1.shape == x_gate2.shape == x_up.shape
# add ctx for backward
if x_gate.requires_grad:
ctx.save_for_backward(x_gate1, x_gate2, x_up)
# allocate output
y = torch.empty_like(x_up)
M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x_gate.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y
@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
# restore from ctx
(x_gate1, x_gate2, x_up) = ctx.saved_tensors
M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)
# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
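A rough sanity check, not part of the diff: the fused kernel can be compared against a plain PyTorch reference built from the forward comment above (y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up, with x_gate split in half on the last dim). It assumes a CUDA device with triton installed.

import torch

def llama_act_combine_reference(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    # eager reference of the fused kernel: x_gate1 * silu(x_gate2) * x_up
    x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, dim=-1)
    return x_gate1 * x_gate2 * torch.sigmoid(x_gate2) * x_up

if torch.cuda.is_available() and HAS_TRITON:
    b, l, d = 2, 16, 64
    x_gate = torch.randn(b, l, 2 * d, device="cuda")
    x_up = torch.randn(b, l, d, device="cuda")
    y_fused = LlamaActCombine.apply(x_gate, x_up)  # forward-only check
    torch.testing.assert_close(y_fused, llama_act_combine_reference(x_gate, x_up), rtol=1e-4, atol=1e-4)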

@@ -1,6 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
@@ -10,6 +9,5 @@ __all__ = [
"DataParallelGradientHandler",
"ZeROGradientHandler",
"PipelineSharedModuleGradientHandler",
"MoeGradientHandler",
"SequenceParallelGradientHandler",
]

@@ -16,7 +16,6 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.context import Config, ConfigException
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
@@ -36,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
def get_default_parser():
@@ -323,8 +321,6 @@ def initialize(
if not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif MOE_CONTEXT.is_initialized:
sync_moe_model_param(model)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
@@ -377,14 +373,6 @@ def initialize(
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_ddp() and MOE_CONTEXT.is_initialized:
gradient_handler_cfg = [dict(type="MoeGradientHandler")]
if verbose:
logger.info(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_sequence():
model = DDP(
model,

@@ -0,0 +1,17 @@
from .checkpoint import MoeCheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
__all__ = [
"MLPExperts",
"MoeRouter",
"Top1Router",
"Top2Router",
"TopKRouter",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoeCheckpintIO",
]

@@ -0,0 +1,275 @@
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from colossalai.moe.manager import MOE_MANAGER
MOE_KERNEL = None
def load_moe():
global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder
MOE_KERNEL = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0), None
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
if not overlap:
dist.all_gather(buffer_list, inputs, group=group)
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0), None
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
if not overlap:
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
# TODO: support async backward
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs, None
output = torch.empty_like(inputs)
if not overlap:
dist.all_to_all_single(output, inputs, group=group)
return output, None
else:
handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
return output, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
None,
None,
)
class MoeDispatch(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
dtype = tokens.dtype
if MOE_KERNEL is None:
load_moe()
if tokens.dtype != torch.float32:
tokens = tokens.to(torch.float32)
expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
if expert_input.dtype != dtype:
expert_input = expert_input.to(dtype)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
ctx.dtype = dtype
return expert_input
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
if output_grad.dtype != torch.float32:
output_grad = output_grad.to(torch.float32)
d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
if d_tokens.dtype != ctx.dtype:
d_tokens = d_tokens.to(ctx.dtype)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
dtype = expert_tokens.dtype
if expert_tokens.dtype != torch.float32:
expert_tokens = expert_tokens.to(torch.float32)
if MOE_KERNEL is None:
load_moe()
output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
if output.dtype != dtype:
output = output.to(dtype)
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.dtype = dtype
return output
@staticmethod
@custom_bwd
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32)
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
mask, dest_idx)
if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype)
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and use_kernel:
if MOE_KERNEL is None:
load_moe()
return MOE_KERNEL.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
class MoeInGradScaler(torch.autograd.Function):
"""
Scale the gradient back up by the expert parallel size in the backward pass,
because the effective batch size increases in the MoE stage
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
if ctx is not None:
ctx.ep_size = ep_size
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad * ctx.ep_size
return grad, None
class MoeOutGradScaler(torch.autograd.Function):
"""
Scale the gradient back down by the expert parallel size in the backward pass,
because the effective batch size increases in the MoE stage
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
ctx.ep_size = ep_size
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
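A small CPU-only sketch (not part of the diff) of the helper semantics above, avoiding the fused kernel and any process groups:

import torch

# moe_cumsum with use_kernel=False is torch.cumsum(...) - 1 along the token dimension,
# i.e. the slot index of each token inside its expert bucket wherever mask == 1.
mask = torch.tensor([[1, 0], [1, 0], [0, 1]])   # 3 tokens routed over 2 experts
print(moe_cumsum(mask, use_kernel=False))       # tensor([[0, -1], [1, -1], [1, 0]])

# MoeInGradScaler / MoeOutGradScaler are identities in the forward pass and only rescale
# gradients by ep_size in the backward pass, so chaining them cancels out.
x = torch.ones(4, requires_grad=True)
y = MoeOutGradScaler.apply(MoeInGradScaler.apply(x, 2), 2)
y.sum().backward()
print(x.grad)                                   # the *2 and /2 cancel: tensor([1., 1., 1., 1.])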

@@ -0,0 +1,274 @@
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict_into_model,
save_config_file,
save_state_dict_shards,
)
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
class MoeCheckpintIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
) -> None:
assert zero_stage in [
0,
1,
2,
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
super().__init__(dp_group, pp_group, tp_group, zero_stage)
self.parallel = MOE_MANAGER.parallel
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
for name, param in state_dict.items():
if ".experts." in name:
if name in dict(model.named_parameters()):
model_param = dict(model.named_parameters())[name]
if is_moe_tensor(model_param):
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = param.shape[0] // ep_size
assert param.shape[0] % ep_size == 0
param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
state_dict[name] = param
dist.barrier()
return state_dict
def _model_sharder(
self,
state_dict: dict,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the model state_dict into shards of limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
for name, param in state_dict.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
state_dict = torch.load(checkpoint)
state_dict = self.pre_load_model(model, state_dict)
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [deepcopy(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Prefix of the checkpoint files to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
dist.barrier()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
raise NotImplementedError()
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
raise NotImplementedError()
def save_sharded_optimizer(
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
):
raise NotImplementedError()
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
raise NotImplementedError()
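Continuing the hypothetical Booster setup from the plugin sketch above (not part of the diff): with MoeHybridParallelPlugin, booster.save_model and booster.load_model route through MoeCheckpintIO, so expert shards are gathered along the ep group by pre_save_model before writing and re-sliced by pre_load_model on load. Optimizer checkpointing is still unimplemented here, so only model state is saved.

# save a sharded model checkpoint (index file + weight shards) and load it back
booster.save_model(model, "./openmoe_ckpt", shard=True, size_per_shard=1024)
booster.load_model(model, "./openmoe_ckpt/pytorch_model.bin.index.json")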

colossalai/moe/experts.py
@@ -0,0 +1,156 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
class MLPExperts(nn.Module):
"""
A set of expert MLPs used inside SparseMLP, with optional expert parallelism (EP) or tensor parallelism (TP).
Args:
num_experts (int): The number of experts
hidden_size (int): The hidden size of MLP
intermediate_size (int): The intermediate size of MLP
expert_parallel (str, optional): The expert parallel mode. One of None, "EP" and "TP".
activation (optional): The activation function of MLP
drop_rate (float, optional): The drop rate of MLP
gated (bool, optional): Whether to use gated MLP
use_kernel (bool, optional): Whether to use kernel optimization
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = None,
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,
use_kernel: Optional[bool] = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
self.gated = gated
self.use_kernel = use_kernel
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
intermediate_size = intermediate_size // self.ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
else:
self.num_local_experts = self.num_total_experts
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)
if expert_parallel is not None:
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
if self.gated:
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
else:
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
def forward(
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
) -> torch.Tensor:
"""
forward: hidden_size --> intermediate_size --> hidden_size
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
x = x.transpose(0, 1)
inshape = x.shape
x = x.reshape(e, -1, h)
if self.use_kernel and use_sparse:
seq_len = x.shape[1]
with torch.no_grad():
mask = x[:, :, 0] != 0.0
mask = torch.sum(mask, dim=-1)
x_list = []
for i in range(e):
x_list.append(x[i, :mask[i]])
x = x_list
if self.gated:
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
else:
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
else:
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
x = [self.act(x[i]) for i in range(e)]
x = [self.drop(x[i]) for i in range(e)]
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
if self.use_kernel and use_sparse:
for i in range(e):
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x
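A minimal single-process sketch (not part of the diff): with expert_parallel=None all experts live on the local device, so no MOE_MANAGER setup or process groups are needed, and the forward maps a (num_groups, num_experts, capacity, hidden_size) input back to the same shape.

import torch

experts = MLPExperts(
    num_experts=4,
    hidden_size=8,
    intermediate_size=16,
    expert_parallel=None,   # keep every expert local
)
x = torch.randn(1, 4, 2, 8)  # (num_groups, num_experts, capacity, hidden_size)
out = experts(x)
print(out.shape)             # torch.Size([1, 4, 2, 8])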

colossalai/moe/layers.py
@@ -0,0 +1,361 @@
import dataclasses
import math
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import get_noise_generator
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
class SparseMLP(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
num_experts (int): The number of experts
hidden_size (int): The hidden dimension of the model
intermediate_size (int): The intermediate dimension of each expert MLP
router_top_k (int, optional): The number of experts each token is dispatched to
router_capacity_factor_train (float, optional): Capacity factor in routing during training
router_capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
router_min_capacity (int, optional): The minimum capacity of each expert
router_noisy_policy (str, optional): The noisy routing policy. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
router_drop_tks (bool, optional): Whether to drop tokens that exceed expert capacity during evaluation
mlp_activation (str, optional): The activation function of the expert MLPs
mlp_gated (bool, optional): Whether the expert MLPs use a gated structure
enable_load_balance (bool, optional): Whether to rebalance experts across devices during training
load_balance_tolerance (float, optional): Tolerance used by the load balancer
load_balance_beam_width (int, optional): Beam width of the load-balancing search
load_balance_group_swap_factor (float, optional): Swap penalty factor of the load balancer
enable_kernel (bool, optional): Whether to use the fused MoE kernels
enable_comm_overlap (bool, optional): Whether to overlap communication with expert computation
.. _Switch Transformer paper:
https://arxiv.org/abs/2101.03961
.. _ViT-MoE paper:
https://arxiv.org/abs/2106.05974
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_capacity_factor_train: Optional[float] = 1.25,
router_capacity_factor_eval: Optional[float] = 2.0,
router_min_capacity: Optional[int] = 4,
router_noisy_policy: Optional[str] = None,
router_drop_tks: Optional[bool] = True,
mlp_activation: Optional[str] = None,
mlp_gated: Optional[bool] = False,
enable_load_balance: Optional[bool] = False,
load_balance_tolerance: Optional[float] = 0.1,
load_balance_beam_width: Optional[int] = 8,
load_balance_group_swap_factor: Optional[float] = 0.4,
enable_kernel: Optional[bool] = False,
enable_comm_overlap: Optional[bool] = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
router_cls = get_router_cls(router_top_k)
self.topk = router_top_k
self.router: MoeRouter = router_cls(
capacity_factor_train=router_capacity_factor_train,
capacity_factor_eval=router_capacity_factor_eval,
min_capacity=router_min_capacity,
noisy_func=noisy_func,
drop_tks=router_drop_tks,
)
# gate
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
# moe experts
self.experts = MLPExperts(
num_experts=self.num_experts,
expert_parallel=self.expert_parallel,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
activation=mlp_activation,
gated=mlp_gated,
use_kernel=self.enable_kernel,
)
# get parallel settings
if self.expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.dp_group = get_dp_group(self.experts)
else:
self.ep_group = None
self.dp_group = None
self.num_local_experts = self.experts.num_local_experts
# load balance
self.enable_load_balance = enable_load_balance
if self.enable_load_balance:
self.load_balancer = LoadBalancer(
experts=self.experts,
gate=self.gate_weight,
local_expert_num=self.num_local_experts,
expert_num=self.num_experts,
ep_group=self.ep_group,
dp_group=self.dp_group,
tolerance=load_balance_tolerance,
beam_width=load_balance_beam_width,
group_swap_factor=load_balance_group_swap_factor,
)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
"""
# reshape the input tokens
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# update expert load
if self.enable_load_balance:
with torch.no_grad():
# TODO: optimize computation
expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
# TODO: bincount introduces synchronize, fix it
expert_load = torch.bincount(expert_load.view(-1))
self.load_balancer.update_load(expert_load)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
expert_out = self.experts(expert_in)
return expert_out
def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
"""
Expert Parallel
Args:
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
return expert_output
else:
@dataclasses.dataclass
class Capsule:
data: torch.Tensor
handle: Any = None
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
output = torch.empty_like(dispatch_data)
offset = 0
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
_expert_out = None
# all2all next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))
# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
expert_in = None
if _expert_in is not None:
expert_in = _expert_in
_expert_in = None
return output
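    # Editor's note on the overlapped branch above: each of the NUM_CHUNK chunks is pushed through
    # all-to-all -> expert computation -> all-to-all, and the (asynchronous) all-to-all of the next
    # chunk is issued while the current chunk is being computed, so communication overlaps with
    # computation across the NUM_CHUNK + NUM_STAGES - 1 loop iterations.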
def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
"""
without overlap:
| C |
| A | | R |
with overlap:
| C1 || C2 || C3 || C4 |
| A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
where C is computation, A is all gather, R is reduce scatter.
Args:
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
return expert_out
else:
@dataclasses.dataclass
class Capsule:
data: torch.Tensor
handle: Any
indices: Tuple
NUM_CHUNK = 4
NUM_STAGES = 4
            assert (dispatch_data.shape[0] % NUM_CHUNK == 0
                    ), "arbitrary chunk num is not supported yet, please use a chunk num that divides num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
return (slice(idx * chunk_size, (idx + 1) * chunk_size),)
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[expert_out.indices] = expert_out.data
expert_out = None
# reduce scatter last output
if _expert_out is not None:
expert_out = Capsule(
*ReduceScatter.apply(_expert_out.data, self.ep_group, True),
indices=_expert_out.indices,
)
_expert_out = None
# all gather next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(
*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
indices=get_chunk_slice(i, chunk_size),
)
# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(
self.experts(expert_in.data, expert_in.indices),
handle=None,
indices=expert_in.indices,
)
expert_in = None
if _expert_in is not None:
expert_in = _expert_in
_expert_in = None
return output
def apply_load_balance(model: nn.Module, optim: Any) -> None:
"""
    Apply load balancing to every expert in the model.
"""
def _apply_recursive(module: nn.Module):
for _, sub_module in module.named_children():
if isinstance(sub_module, SparseMLP):
                if sub_module.enable_load_balance:
sub_module.load_balancer.balance_load(optim)
_apply_recursive(sub_module)
torch.cuda.empty_cache()
_apply_recursive(model)
torch.cuda.empty_cache()
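
# Illustrative usage sketch (editor's addition, hypothetical names): `apply_load_balance` is intended to be
# called periodically from the training loop so that accumulated expert-load statistics trigger expert swaps.
#
#     for step, batch in enumerate(dataloader):
#         loss = model(**batch)
#         booster.backward(loss, optimizer)
#         optimizer.step()
#         optimizer.zero_grad()
#         if (step + 1) % load_balance_interval == 0:
#             apply_load_balance(model, optimizer)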


@ -0,0 +1,442 @@
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.zero.low_level import LowLevelZeroOptimizer
class LoadBalancer:
def __init__(
self,
experts: MLPExperts,
gate: nn.Parameter,
local_expert_num: int,
expert_num: int,
ep_group: ProcessGroup,
dp_group: ProcessGroup,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
) -> None:
self.experts: MLPExperts = experts
self.gate: nn.Parameter = gate
self.moe_ep_group: ProcessGroup = ep_group
self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
self.moe_dp_group: ProcessGroup = dp_group
self.tolerance = tolerance
self.beam_width = beam_width
self.group_swap_factor = group_swap_factor
self.local_expert_num = local_expert_num
self.expert_num = expert_num
self.local_load = None
# TODO: use a global process group mesh
pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
self.global_dp_group = global_dp_group.get_group_along_axis(1)
self.global_dp_rank = dist.get_rank(self.global_dp_group)
self.global_dp_size = dist.get_world_size(self.global_dp_group)
def _clear_load(self) -> None:
self.local_load = None
def _sync_load(self) -> Tensor:
new_load = self.local_load.clone().detach()
# all reduce load between ep group
dist.all_reduce(new_load, group=self.moe_ep_group)
# all reduce load between dp group
dist.all_reduce(new_load, group=self.moe_dp_group)
return new_load
@staticmethod
def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
return abs(sum(data[group]) / len(data[group]) - avg)
@staticmethod
def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
data[group_i][index_i], data[group_j][index_j] = (
data[group_j][index_j],
data[group_i][index_i],
)
@staticmethod
def _normalize_data(data: List) -> List:
max_value = max(max(sublist) for sublist in data)
data = [[i / max_value for i in sublist] for sublist in data]
return data
@staticmethod
def _get_swap_loss(
group_swap_factor: float,
swap_list: List,
group_i: int,
index_i: int,
group_j: int,
index_j: int,
) -> float:
"""
        Get the swap loss. The swap loss penalizes swapping the same index twice
        and swapping the same group multiple times.
"""
swap_loss = 0
for swap in swap_list:
for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
# the group has been swapped
if group_id in [swap[0], swap[2]]:
# the index has been swapped
# we want to avoid the situation that the same index is swapped twice
if index_id in [swap[1], swap[3]]:
swap_loss += 1e5
# the index has not been swapped
                    # this is acceptable, but should happen as rarely as possible
else:
swap_loss += group_swap_factor
return swap_loss
@staticmethod
def _check_convergence(data: List, avg: float, tolerance: float):
"""
Check whether the data is converged after swap.
"""
for sublist in data:
if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
return False
return True
def _beam_search(
self,
inputs: Tuple[List, float, List],
beam_width: int,
avg: float,
group_swap_factor: float,
) -> List:
"""
Beam search for the best swap combination.
Specifically, we swap two elements from two groups and calculate the score.
The score is the difference between the origin group sum and the new group sum.
The larger the score, the better the swap combination.
Args:
inputs (Tuple): (data, origin_score, swap_list)
beam_width (int): beam width for beam search
avg (float): average value of the data
group_swap_factor (float): group loss for group swap loss
Returns:
List: results list
"""
data, origin_score, swap_list = inputs
results = []
group_num = len(data)
group_size = len(data[0])
origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]
for group_num_i in range(group_num):
for group_size_i in range(group_size):
for group_num_j in range(group_num_i + 1, group_num):
for group_size_j in range(group_size):
new_data = deepcopy(data)
# calculate origin group sum
origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
# swap data
self._swap_data(
new_data,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
# calculate new group sum
new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
new_data, group_num_j, avg
)
                        # calculate score
new_score = origin_diff - new_diff
if new_score > 0:
new_score = origin_score + new_score
# get swap loss
swap_loss = self._get_swap_loss(
group_swap_factor,
swap_list,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
new_score = new_score - swap_loss
# update swap list
new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
results.append((new_data, new_score, new_swap_list))
# sort results
results.sort(key=lambda x: x[1], reverse=True)
# select top k results
results = results[:beam_width]
return results
def _load_to_list(self, load: Tensor) -> List:
load_len = len(load)
assert load_len % self.local_expert_num == 0
load_list = []
tmp_list = []
for i in range(len(load)):
tmp_list.append(float(load[i]))
if (i + 1) % self.local_expert_num == 0:
load_list.append(tmp_list)
tmp_list = []
return load_list
def _search_balance(
self,
data: List,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
return_swapped_data: Optional[bool] = False,
) -> Tuple[List, List]:
"""
        Search for the best swap combination to balance the data within the specified tolerance,
        and return the swap list (plus the swapped data if `return_swapped_data` is True).
        Each entry in the swap list is a tuple recording one swap operation.
Args:
data (List): expert load list.
E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
                This means there are 4 devices and each device has 2 experts.
The value is the load of the expert.
tolerance (float): tolerance for balance.
beam_width (int): beam width for beam search.
group_swap_factor (float): group swap factor for group swap loss.
                The larger it is, the fewer times a group will be swapped.
return_swapped_data (bool): whether to return the swapped data.
Returns:
Tuple: (balanced data, swap list).
The swap list is a list of tuples. Each tuple is a swap operation.
E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
the first expert of the first device is swapped with the first expert
of the second device.
"""
norm_data = self._normalize_data(data)
avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
results = [(norm_data, 0, [])]
stop_flag = False
        while not stop_flag:
new_results = []
best_score = results[0][1]
for i in range(len(results)):
new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
if len(new_results) == 0:
stop_flag = True
break
new_results.sort(key=lambda x: x[1], reverse=True)
new_best_score = new_results[0][1]
if new_best_score == best_score:
stop_flag = True
break
new_results = new_results[:beam_width]
results = new_results
for i in results:
if self._check_convergence(results[0][0], avg, tolerance):
stop_flag = True
break
swap_list = results[0][2]
if return_swapped_data:
out = deepcopy(data)
for swap in swap_list:
self._swap_data(out, *swap)
return out, swap_list
else:
return swap_list
@staticmethod
def _swap_expert_single_tensor(
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
):
# exchange weight
local_weight = weight.data[expert_idx]
new_weight = torch.empty_like(local_weight)
if send_first:
dist.send(local_weight, dst=comm_rank, group=comm_group)
dist.recv(new_weight, src=comm_rank, group=comm_group)
else:
dist.recv(new_weight, src=comm_rank, group=comm_group)
dist.send(local_weight, dst=comm_rank, group=comm_group)
weight.data[expert_idx] = new_weight
def _swap_expert_param_and_optim(
self,
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
optim: LowLevelZeroOptimizer,
):
# need to update master and working param if master param exists
# else just update working param
if weight in optim.optim.state:
master_weight_ptr = None
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
# exchange weight
self._swap_expert_single_tensor(
working_weight_ptr,
expert_idx,
comm_group,
send_first,
comm_rank,
)
if master_weight_ptr is not None:
# TODO: exchange master weight, skip for now
# master weight is shared by dp group
tmp = working_weight_ptr.view(-1).split(
working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
)[dist.get_rank(self.moe_dp_group)]
master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
# exchange optim
self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)
def _gather_global_dp_group(self, data: Tensor) -> Tensor:
data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
dist.all_gather(data_list, data, group=self.global_dp_group)
data_list = torch.cat(data_list, dim=0)
return data_list
def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
"""
Swap moe param and optim.
        We use different strategies for experts and the gate.
        For experts, we exchange the expert's parameters and optimizer states via p2p communication.
        For the gate, we all-gather the gate and then pick the part we need.
Args:
swap_list (List)
optim (LowLevelZeroOptimizer)
"""
# get all experts weights
local_rank = dist.get_rank(self.moe_ep_group)
if self.experts.gated:
weight_list = [self.experts.wi_up, self.experts.wi_gate]
else:
weight_list = [self.experts.wi]
weight_list.append(self.experts.wo)
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather
global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
assert (
self.gate.shape
== global_master_gate_weight.shape
== global_gate_exp_avg.shape
== global_gate_exp_avg_sq.shape
)
for swap in swap_list:
source_group, source_idx, target_group, target_idx = swap
source_rank = self.moe_ep_ranks[source_group]
target_rank = self.moe_ep_ranks[target_group]
# exchange expert
if local_rank in [source_group, target_group]:
for weight in weight_list:
if local_rank == source_group:
self._swap_expert_param_and_optim(
weight,
source_idx,
self.moe_ep_group,
True,
target_rank,
optim,
)
elif local_rank == target_group:
self._swap_expert_param_and_optim(
weight,
target_idx,
self.moe_ep_group,
False,
source_rank,
optim,
)
# exchange gate
source_expert_pos = source_group * self.local_expert_num + source_idx
target_expert_pos = target_group * self.local_expert_num + target_idx
for gate in [
self.gate,
global_master_gate_weight,
global_gate_exp_avg,
global_gate_exp_avg_sq,
]:
origin_source = gate.data[source_expert_pos].clone().detach()
origin_target = gate.data[target_expert_pos].clone().detach()
gate.data[source_expert_pos], gate.data[target_expert_pos] = (
origin_target,
origin_source,
)
# update gate
global_master_gate_weight = global_master_gate_weight.view(-1).split(
global_master_gate_weight.numel() // self.global_dp_size
)[self.global_dp_rank]
master_gate_weight.data.copy_(global_master_gate_weight)
global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
self.global_dp_rank
]
gate_exp_avg.data.copy_(global_gate_exp_avg)
global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
global_gate_exp_avg_sq.numel() // self.global_dp_size
)[self.global_dp_rank]
gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)
@torch.no_grad()
def update_load(self, load: Tensor) -> None:
if len(load) != self.expert_num:
padding_size = self.expert_num - len(load)
padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
load = torch.cat((load, padding), dim=0)
if self.local_load is None:
self.local_load = load
else:
self.local_load += load
@torch.no_grad()
def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
# prepare load
load = self._sync_load()
load = self._load_to_list(load)
# search balance
swap_list = self._search_balance(load)
        if dist.get_rank() == 0:
            if len(swap_list) > 0:
                print("[Load Balance] Applying expert swap...")
            else:
                print("[Load Balance] No valid swap found, skipping...")
# swap expert and gate
self._swap_moe_param(swap_list, optim)
# clear load
self._clear_load()
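
# Editor's sketch of the swap-list format used above (illustrative): for the example load
# [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] from `_search_balance`'s docstring, there are
# 4 EP ranks with 2 local experts each. A swap entry such as (0, 0, 1, 1) would mean "exchange expert 0
# on rank 0 with expert 1 on rank 1": the expert's MLP weights and Adam states are exchanged by p2p in
# `_swap_expert_param_and_optim`, and the corresponding rows of the gate are swapped in `_swap_moe_param`.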


@ -1,11 +1,9 @@
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.registry import LOSSES
from colossalai.moe.manager import MOE_MANAGER
@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
@ -45,11 +43,10 @@ class MoeCrossEntropyLoss(_Loss):
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
main_loss = self.loss(*args)
aux_loss = MOE_CONTEXT.get_loss()
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
@LOSSES.register_module
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
@ -77,5 +74,5 @@ class MoeLoss(_Loss):
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
"""
main_loss = self.loss_fn(*args, **kwargs)
aux_loss = MOE_CONTEXT.get_loss()
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss

colossalai/moe/manager.py

@ -0,0 +1,162 @@
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
class MoeManager(metaclass=SingletonMeta):
"""MoE manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.parallel = None
self.seed = None
self.mode = None
self.use_ep_inside = None
self.world_size = None
self._parallel_info_dict = dict()
# router
self.router_aux_loss = []
self.router_z_loss = []
# fixed mode
self.pp_size = None
self.dp_size = None
self.ep_size = None
# dynamic mode
# Users may want to set maximum expert parallel size smaller than the world size
# since very low bandwidth across nodes may constrain the performance of MoE
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = None
self.has_setup = False
@property
def parallel_info_dict(self):
return self._parallel_info_dict
@property
def is_initialized(self):
return self.has_setup
def setup(
self,
seed: int,
parallel: str = None,
mode: str = "dynamic",
max_ep_size: int = 8,
fixed_dp_size: int = 0,
fixed_ep_size: int = 0,
fixed_pp_size: int = 0,
use_ep_inside: bool = True,
) -> None:
"""
Setup MoE distributed context.
Args:
            seed (int): Random seed.
            parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
                In fixed mode, the ep size and dp size are fixed.
                In dynamic mode, the ep size and dp size are determined by the number of experts.
max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
"""
assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
self.seed = seed + dist.get_rank()
self.parallel = parallel
self.use_ep_inside = use_ep_inside
self.world_size = dist.get_world_size()
# init by mode
self.mode = mode
assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
if self.mode == "dynamic":
self.max_ep_size = min(max_ep_size, self.world_size)
else:
assert (fixed_dp_size > 0 and fixed_ep_size > 0
and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
self.ep_size = fixed_ep_size
self.dp_size = fixed_dp_size
self.pp_size = fixed_pp_size
self.has_setup = True
def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
"""Calculate the Data Parallel Group and Expert Parallel Group.
        Args:
            num_experts (int): The number of experts.
            use_tp (bool, optional): Whether tensor parallelism is used for the experts. Defaults to False.

        Returns:
            Tuple[int, MoeParallelInfo]: The number of local experts and the MoeParallelInfo of the current ep_size.
        """
if self.mode == "dynamic":
            gt_flag = (num_experts % self.max_ep_size == 0)  # num_experts is a multiple of max_ep_size
            lt_flag = (self.max_ep_size % num_experts == 0)  # max_ep_size is a multiple of num_experts
            assert gt_flag or lt_flag, ("Automatic expert placement requires the number of experts to be"
                                        " a multiple of the ep size, or vice versa.")
dp_size = 1 if gt_flag else self.world_size // num_experts
ep_size = min(self.world_size // dp_size, self.max_ep_size)
dp_size = self.world_size // ep_size
pp_size = 1
else:
dp_size = self.dp_size
ep_size = self.ep_size
pp_size = self.pp_size
# Calculate the number of experts for each GPU
if use_tp:
num_local_experts = num_experts
else:
if self.mode == "dynamic":
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
else:
num_local_experts = num_experts // ep_size
if not (ep_size in self.parallel_info_dict):
self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
if dist.get_rank() == 0:
if self.use_ep_inside:
print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
else:
print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")
return num_local_experts, self.parallel_info_dict[ep_size]
def reset_loss(self):
self.router_aux_loss, self.router_z_loss = [], []
def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
self.router_aux_loss.append(aux_loss)
self.router_z_loss.append(z_loss)
def get_loss(self):
cur_loss = self.router_aux_loss, self.router_z_loss
return cur_loss
def get_parallel(self):
return self.parallel
MOE_MANAGER = MoeManager()
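
# Illustrative usage sketch (editor's addition): setting up the singleton manager in dynamic mode.
# This assumes torch.distributed has already been initialized; the argument values are only examples.
#
#     MOE_MANAGER.setup(seed=42, parallel="EP", mode="dynamic", max_ep_size=8)
#     num_local_experts, parallel_info = MOE_MANAGER.get_info(num_experts=32)
#     # e.g. with a world size of 8 this gives ep_size == 8 and num_local_experts == 4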

colossalai/moe/routers.py

@ -0,0 +1,419 @@
import math
from abc import ABC
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor in routing of training.
capacity_factor_eval (float): Capacity factor in routing of evaluation.
min_capacity (int): The minimum number of the capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._aux_loss = None
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, logits_shape):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return int(capacity)
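    # Worked example (editor's note): with top-2 routing, capacity_factor_train = 1.25,
    # 1024 tokens per group and 8 experts, capacity = floor(2 * 1.25 * 1024 / 8) = 320,
    # which is already even and above min_capacity, so each expert buffers up to 320 tokens.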
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
"""Computes auxiliary load balancing loss as in Switch Transformer.
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
implements the loss function presented in equations (4) - (6). It aims to
penalize those cases where the routing between experts is unbalanced.
Args:
router_probs: Probability assigned to each expert per token. Shape:
<float32>[num_groups, tokens_per_group, num_experts].
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
indices identifying the top num_selected_experts for a given token.
"""
assert self._aux_loss is None
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [num_groups, tokens_per_group, num_experts]
expert_mask = expert_mask.max(dim=-2)[0]
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
self._aux_loss = aux_loss
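    # Editor's note: the code above computes aux_loss = E^2 * mean_e(f_e * P_e) = E * sum_e(f_e * P_e),
    # where E is the number of experts, f_e is the fraction of tokens routed to expert e and P_e is the
    # mean router probability of expert e, i.e. the Switch Transformer load-balancing loss with alpha = 1.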
def set_z_loss(self, router_logits: torch.Tensor):
"""Compute router z-loss.
The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.
Args:
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
"""
assert self._z_loss is None
if router_logits.dim() == 2:
router_logits = router_logits.unsqueeze(0)
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
num_groups, tokens_per_group, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
self._z_loss = z_loss
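    # Editor's note: the z-loss above is the mean over all tokens of (logsumexp(router_logits))^2,
    # which encourages the router logits to stay small for numerical stability.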
def pop_router_loss(self) -> torch.Tensor:
assert self._aux_loss is not None
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
self._aux_loss = None
self._z_loss = None
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
    can be found in Google's Switch Transformer paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert.
select_policy (str, optional): The policy about tokens selection.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
        # calculate router loss
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
elif self.select_policy == "first":
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
ranks = torch.sum(mask * ranks, dim=-1)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return probs, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
    can be found in the ViT-MoE paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
        # calculate router loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return probs, mask, dest_idx, num_experts * capacity
else:
# >>> original code
# weight1 = mask1 * probs.type_as(inputs)
# weight2 = mask2 * probs.type_as(inputs)
# rank1_sc = F.one_hot(rank1, num_classes=capacity)
# rank2_sc = F.one_hot(rank2, num_classes=capacity)
# cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
# cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
# cb_weight = cb_weight1 + cb_weight2
# sec_mask = cb_weight.bool()
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
return cb_weight, sec_mask
class TopKRouter(MoeRouter):
"""Masked matmul router using tokens choose top-k experts assignment.
NOTE: this is modified from flaxformer.
This router uses the same mechanism as in Switch Transformer
(https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
sorted by router_probs and then routed to their choice of expert until the
expert's expert_capacity is reached. There is no guarantee that each token is
processed by an expert, or that each expert receives at least one token.
Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""
def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
drop_tks)
def forward(
self,
router_probs: torch.Tensor,
expert_capacity: int,
) -> Tuple:
"""Computes masks for the top-k experts per token.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
# TODO: add parallel group
num_groups, _, num_experts = router_probs.shape
# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
self.set_aux_loss(router_probs, expert_index, num_experts)
self.pop_router_loss()
# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 1, 2)
# Shape: [num_groups, num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(num_groups, -1)
# Create mask out of indices.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 1, 2)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [num_groups, tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=2)[0]
# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
return combine_array, dispatch_mask
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
if not grouped:
if top_k == 1:
return Top1Router
elif top_k == 2:
return Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
else:
return TopKRouter
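
# Illustrative usage sketch (editor's addition): constructing a router through the factory above.
# `gate_logits` is a placeholder for fp32 logits of shape (batch_size * seq_len, num_experts).
#
#     router_cls = get_router_cls(top_k=2)
#     router = router_cls(capacity_factor_train=1.25, capacity_factor_eval=2.0)
#     combine_weights, sec_mask = router(gate_logits.float())  # dense (non-kernel) outputs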

colossalai/moe/utils.py

@ -0,0 +1,177 @@
import contextlib
from typing import Any, Callable, Dict, List
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.utils import get_current_device
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
class NormalNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
`E = the number of experts`.
Args:
num_experts (int): The number of experts.
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
class UniformNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
copied from mesh tensorflow:
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
Args:
eps (float, optional): Epsilon in generator, defaults 1e-2.
"""
def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy
def autocast_softmax(logit: torch.Tensor, dim: int):
    return F.softmax(logit, dim=dim, dtype=torch.float32)
def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
if noise_type is None:
return None
elif noise_type == "Jitter":
noisy_func = UniformNoiseGenerator()
elif noise_type == "Gaussian":
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
return noisy_func
def get_activation(act: str) -> Callable:
if act is None or act == "relu":
return torch.nn.ReLU()
elif act == "gelu":
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
else:
raise NotImplementedError("Unsupported activation function")
def SwiGLU(x):
"""Gated linear unit activation function.
Args:
        x: input tensor; the split is computed along the last dimension
"""
size = x.shape[-1]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = torch.split(x, size // 2, -1)
return x1 * (x2 * torch.sigmoid(x2))
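
# Quick example (editor's addition): SwiGLU halves the last dimension.
#
#     x = torch.randn(4, 8)
#     y = SwiGLU(x)  # y.shape == (4, 4)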
@contextlib.contextmanager
def skip_init():
"""
    Skip random parameter initialization; the patched init functions are restored on exit.
"""
def _skip_init(*args, **kwargs):
pass
init_func = {
"constant_": torch.nn.init.constant_,
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"trunc_normal_": torch.nn.init.trunc_normal_,
}
for method_name, original_init in init_func.items():
setattr(torch.nn.init, method_name, _skip_init)
yield
for method_name, original_init in init_func.items():
setattr(torch.nn.init, method_name, original_init)
return
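
# Illustrative usage sketch (editor's addition): build a module without running the (potentially expensive)
# random weight initialization; the patched init functions are restored when the context exits.
#
#     with skip_init():
#         layer = torch.nn.Linear(1024, 1024)  # weights are allocated but left uninitialized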
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
"""Returns a parameter dictionary, the key of which is the expert parallel
    size of every parameter. Since data-parallel parameters are replicated on each GPU,
    we set their ep_size to 1.
Args:
        model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which the dict is built.
"""
epsize_param_dict = dict()
for param in model.parameters():
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
return epsize_param_dict
def sync_moe_model_param(model: nn.Module):
"""Make sure model parameters are consistent in MoE parallel context.
Args:
        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are synchronized.
"""
param_dict = get_moe_epsize_param_dict(model)
# synchronize the parameters whose dp_group is the whole world
if 1 in param_dict:
for param in param_dict[1]:
dist.broadcast(param, src=0)
for ep_size in param_dict:
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
for param in param_dict[ep_size]:
src_rank = get_dp_group_ranks(param)[0]
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
def set_moe_args(config: Any, args: dict):
for k, v in args.items():
setattr(config, k, v)
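
# Minimal sketch (editor's addition, hypothetical attribute names): `set_moe_args` simply copies a dict of
# MoE options onto a config object such as a transformers model config.
#
#     set_moe_args(model.config, {"moe_layer_interval": 4, "router_aux_loss_factor": 0.01})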


@ -1,2 +1 @@
# from .moe import *
from .utils import *


@ -1,21 +0,0 @@
from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
"Experts",
"FFNExperts",
"TPExperts",
"Top1Router",
"Top2Router",
"MoeLayer",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"build_ffn_experts",
"MoeModule",
"MoeRouter",
"save_moe_model",
"load_moe_model",
]


@ -1,171 +0,0 @@
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
try:
from colossalai._C import moe
except:
moe = None
def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
fp16_flag = expert_tokens.dtype == torch.float16
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1


@ -1,40 +0,0 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from .experts import MoeExperts
def save_moe_model(model: nn.Module, save_path: str):
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, save_path)
dist.barrier()
def load_moe_model(model: nn.Module, load_path: str):
state_dict = torch.load(load_path)
for prefix, module in model.named_modules():
if prefix.endswith(".moe_layer.experts"):
# this module should be an Experts instance
assert isinstance(module, MoeExperts)
ep_rank = dist.get_rank(module.dist_info.ep_group)
num_local = module.num_local_experts
for i in range(num_local):
expert_id = ep_rank * num_local + i
for name, _ in module.experts[i].named_parameters():
cur_key = f"{prefix}.experts.{i}.{name}"
param_key = f"{prefix}.experts.{expert_id}.{name}"
load_param = state_dict[param_key]
state_dict[cur_key] = load_param
for name, _ in module.experts[0].named_parameters():
pop_pre = f"{prefix}.experts."
pop_suf = f".{name}"
for i in range(num_local, module.num_total_experts):
pop_key = f"{pop_pre}{i}{pop_suf}"
state_dict.pop(pop_key)
model.load_state_dict(state_dict)


@ -1,201 +0,0 @@
import math
from copy import deepcopy
from typing import Type
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator
from colossalai.utils import get_current_device
class MoeExperts(nn.Module):
"""Basic class for experts in MoE. It stores what kind of communication experts use
to exchange tokens, how many experts in a single GPU and parallel information such as
expert parallel size, data parallel size and their distributed communication groups.
"""
def __init__(self, comm_name: str, num_experts: int):
super().__init__()
assert comm_name in {
"all_to_all",
"all_gather",
}, "This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm_name = comm_name
self.num_total_experts = num_experts
# Get the configuration of experts' deployment and parallel information from moe context
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
    is an instance of the expert class passed in the initialization parameters.
Args:
expert_cls (:class:`torch.nn.Module`): The class of all experts
num_experts (int): The number of experts
expert_args: Args used to initialize experts, the args could be found in corresponding expert class
"""
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts)
# Use seed to make every expert different from others
with seed(ParallelMode.TENSOR):
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
# Attach parallel information for all parameters in Experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__("moe_info", self.dist_info)
def forward(self, inputs: torch.Tensor):
# Split inputs for each expert
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
# Get outputs from each expert
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
# Concatenate all outputs together
output = torch.cat(expert_output, dim=1).contiguous()
return output
def state_dict(self, destination=None, prefix="", keep_vars=False):
assert keep_vars == False, "Only support keep_vars=False now"
dp_rank = dist.get_rank(self.dist_info.dp_group)
ep_rank = dist.get_rank(self.dist_info.ep_group)
submodule_dict = dict()
example_submodule = None
for name, subm in self.experts.named_modules():
if subm is self.experts:
continue
module_number = self.num_local_experts * ep_rank + int(name)
submodule_dict[module_number] = subm
example_submodule = subm
if dp_rank == 0:
local_prefix = prefix + "experts."
buffer_module = deepcopy(example_submodule)
for i in range(self.num_total_experts):
source_rank = i // self.num_local_experts
current_prefix = local_prefix + str(i) + "."
comm_module = submodule_dict.get(i, buffer_module)
for name, param in comm_module.named_parameters():
dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
if ep_rank == 0:
destination[current_prefix + name] = param.data.cpu()
dist.barrier()
class FFNExperts(MoeExperts):
"""Use torch.bmm to speed up for multiple experts."""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_to_all", num_experts)
self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__("moe_info", self.dist_info)
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
    case that the number of experts can't be divided by the maximum expert parallel size, or the
    maximum expert parallel size can't be divided by the number of experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size"
p_ff = d_ff // MOE_CONTEXT.max_ep_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__("moe_info", self.dist_info)
self.w2.__setattr__("moe_info", self.dist_info)
self.b1.__setattr__("moe_info", self.dist_info)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
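
For reference, the core of both expert implementations above is a per-expert batched matmul. The following standalone sketch (not part of this diff; sizes are illustrative assumptions) shows the same two `torch.baddbmm` projections on dummy tensors:

```python
# Minimal sketch of the batched expert FFN used by FFNExperts/TPExperts.
# All sizes are illustrative; b1/b2 broadcast over the per-expert token dimension.
import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, d_model, d_ff = 4, 8, 16, 32

x = torch.randn(num_experts, tokens_per_expert, d_model)   # [e, c, h]
w1 = torch.randn(num_experts, d_model, d_ff)               # per-expert first projection
b1 = torch.zeros(num_experts, 1, d_ff)
w2 = torch.randn(num_experts, d_ff, d_model)               # per-expert second projection
b2 = torch.zeros(num_experts, 1, d_model)

hidden = F.gelu(torch.baddbmm(b1, x, w1))                  # [e, c, d_ff]
out = torch.baddbmm(b2, hidden, w2)                        # [e, c, h], one matmul per expert
assert out.shape == (num_experts, tokens_per_expert, d_model)
```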

View File

@ -1,212 +0,0 @@
import math
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from colossalai.nn.layer.moe._operation import (
COL_MOE_KERNEL_FLAG,
AllGather,
AllToAll,
MoeCombine,
MoeDispatch,
ReduceScatter,
)
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all communication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
Args:
dim_model (int): Dimension of model.
num_experts (int): The number of experts.
router (MoeRouter): Instance of router used in routing.
experts (MoeExperts): Instance of experts generated by the Experts builder.
"""
def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router: MoeRouter = router
self.experts: MoeExperts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
self.ep_group = experts.dist_info.ep_group
self.ep_size = experts.dist_info.ep_size
self.num_local_experts = experts.num_local_experts
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output
def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out
def forward(self, inputs: torch.Tensor) -> Tuple:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h]
if self.experts.comm_name == "all_to_all":
expert_output = self.a2a_process(dispatch_data)
elif self.experts.comm_name == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n Please use Experts " "build function."
)
# expert_output [e, c, h]
if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
l_aux = self.router.pop_routing_loss()
return ans, l_aux
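
To make the dense (non-kernel) dispatch/combine path above easier to follow, here is a shape-level sketch with a hand-crafted routing decision. It is illustration only: real masks and weights come from Top1Router/Top2Router, and the expert computation is replaced by an identity.

```python
# Shape-level sketch of dense dispatch/combine; routing decision is hand-crafted.
import torch

s, e, c, h = 6, 2, 3, 4                            # tokens, experts, capacity, hidden size
tokens = torch.randn(s, h)

expert_idx = torch.tensor([0, 1, 0, 1, 0, 1])      # which expert each token goes to
slot_idx = torch.tensor([0, 0, 1, 1, 2, 2])        # which capacity slot it occupies
sec_mask = torch.zeros(s, e, c, dtype=torch.bool)
sec_mask[torch.arange(s), expert_idx, slot_idx] = True
combine_weights = sec_mask.float()                 # real routers use gate probabilities here

# dispatch: [e, c, s] @ [s, h] -> [e, c, h]
dispatch_data = torch.matmul(sec_mask.float().permute(1, 2, 0), tokens)

expert_output = dispatch_data                      # stand-in for the expert computation

# combine: [s, e*c] @ [e*c, h] -> [s, h]
out = torch.matmul(combine_weights.view(s, -1), expert_output.view(-1, h))
assert torch.allclose(out, tokens)                 # identity "experts" + weight 1 round-trips the tokens
```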
class MoeModule(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
dim_model (int): Hidden dimension of training model
num_experts (int): The number of experts
top_k (int, optional): The number of experts each token is dispatched to
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
drop_tks (bool, optional): Whether to drop tokens during evaluation
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
More information can be found in `Microsoft paper`_.
residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
expert_args (optional): The args of expert when no instance is given
.. _Switch Transformer paper:
https://arxiv.org/abs/2101.03961
.. _ViT-MoE paper:
https://arxiv.org/abs/2106.05974
.. _Microsoft paper:
https://arxiv.org/abs/2201.05596
"""
def __init__(
self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args,
):
super().__init__()
noisy_func = None
if noisy_policy is not None:
if noisy_policy == "Jitter":
noisy_func = UniformNoiseGenerator()
elif noisy_policy == "Gaussian":
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
if top_k == 1:
moe_router_cls = Top1Router
elif top_k == 2:
moe_router_cls = Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
self.moe_router = moe_router_cls(
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
self.residual_module = residual_instance
else:
assert expert_cls is not None, "Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)
with no_shard_zero_context():
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
if expert_instance is not None:
my_experts = expert_instance
else:
assert expert_cls is not None, "Expert class can't be None when experts instance is not given"
my_experts = Experts(expert_cls, num_experts, **expert_args)
self.moe_layer = MoeLayer(
dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts
)
def forward(self, inputs: torch.Tensor):
moe_output, l_aux = self.moe_layer(inputs)
if self.use_residual:
residual_output = self.residual_module(inputs)
combine_coef = self.residual_combine(inputs)
combine_coef = F.softmax(combine_coef, dim=-1)
output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
else:
output = moe_output
return output, l_aux

View File

@ -1,235 +0,0 @@
import math
from abc import ABC
from typing import Callable, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.nn.layer.moe._operation import moe_cumsum
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor in routing of training.
capacity_factor_eval (float): Capacity factor in routing of evaluation.
min_capacity (int): The minimum number of the capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation
"""
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Callable = None,
drop_tks: bool = True,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._routing_loss = None
def get_capacity(self, logits_shape):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
assert self._routing_loss is None
self._routing_loss = aux_loss
def pop_routing_loss(self) -> torch.Tensor:
assert self._routing_loss is not None
reservation = self._routing_loss
self._routing_loss = None
return reservation
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More details can be found in Google's Switch Transformer paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert.
select_policy (str, optional): The policy about tokens selection.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
# calculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
self.set_routing_loss(l_aux)
if not self.training and not self.drop_tks:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask)
elif self.select_policy == "first":
ranks = moe_cumsum(mask)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
ranks = torch.sum(mask * ranks, dim=-1)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More details can be found in the ViT-MoE paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether to drop tokens during evaluation.
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
# inputs: [s, h]
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = mask1 + mask2 # loss: [s, e]
# calculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
self.set_routing_loss(l_aux)
if not self.training and not self.drop_tks:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
return cb_weight, sec_mask
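
As a concrete illustration of the shared `get_capacity` rule used by both routers (numbers are illustrative): with 1024 tokens, 8 experts, top-2 routing and a training capacity factor of 1.25, the per-expert capacity is `floor(2 * 1.25 * 1024 / 8) = 320`, bumped to the next even number if odd and clamped below by `min_capacity`.

```python
# Standalone re-implementation of the capacity rule above, for illustration only.
import math

def capacity(k: int, factor: float, num_tokens: int, num_experts: int, min_capacity: int = 4) -> int:
    cap = math.floor(k * factor * num_tokens / num_experts)
    cap += cap % 2                       # keep the capacity even
    return max(cap, min_capacity)        # never go below min_capacity

print(capacity(k=2, factor=1.25, num_tokens=1024, num_experts=8))   # -> 320
```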

View File

@ -1,71 +0,0 @@
import torch
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
class NormalNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
`E = the number of experts`.
Args:
num_experts (int): The number of experts.
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
class UniformNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
copied from mesh tensorflow:
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
Args:
eps (float, optional): Epsilon in generator, defaults 1e-2.
"""
def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy
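
As a quick illustration of the multiplicative jitter noise above (all values are assumptions for illustration only):

```python
# Standalone sketch of jitter noise: scale gate logits by factors drawn from
# U(1 - eps, 1 + eps), as UniformNoiseGenerator does. Sizes are illustrative.
import torch

eps = 1e-2
logits = torch.randn(4, 8)                                   # [tokens, experts]
noise = torch.empty_like(logits).uniform_(1.0 - eps, 1.0 + eps)
noisy_logits = logits * noise
print(noisy_logits.shape)                                    # torch.Size([4, 8])
```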
def autocast_softmax(logit: torch.Tensor, dim: int):
if logit.dtype != torch.float32:
logit = logit.float()
return F.softmax(logit, dim=dim)
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
mep_size = MOE_CONTEXT.max_ep_size
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % mep_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
else:
raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")

View File

@ -1 +0,0 @@
# from .loss_moe import MoeCrossEntropyLoss, MoeLoss

View File

View File

@ -0,0 +1,137 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from .moe_info import MoeParallelInfo
def is_moe_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a moe tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a moe tensor.
"""
return hasattr(tensor, "moe_info")
def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
"""
Set moe info for the given tensor.
Args:
tensor (torch.Tensor): The tensor to be set.
moe_info (dict): The moe info to be set.
"""
tensor.__setattr__("moe_info", moe_info)
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
"""
Create the MoE parallel info for the given parallel sizes.
Args:
ep_size (int): The expert parallel size.
dp_size (int): The data parallel size.
pp_size (int): The pipeline parallel size.
ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False.
Returns:
MoeParallelInfo: The MoE parallel info built from the given sizes.
"""
return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)
def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
"""
Get the expert parallel group of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
"""
return tensor.moe_info.ep_group
def get_ep_size(tensor: torch.Tensor) -> int:
"""
Get the expert parallel size of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
int: The expert parallel size of the given tensor.
"""
return tensor.moe_info.ep_size
def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
"""
Get the data parallel group of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
torch.distributed.ProcessGroup: The data parallel group of the given tensor.
"""
return tensor.moe_info.dp_group
def get_ep_rank(tensor: torch.Tensor) -> int:
"""
Get the expert parallel rank of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
int: The expert parallel rank of the given tensor.
"""
return dist.get_rank(get_ep_group(tensor))
def get_dp_rank(tensor: torch.Tensor) -> int:
"""
Get the data parallel rank of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
int: The data parallel rank of the given tensor.
"""
return dist.get_rank(get_dp_group(tensor))
def get_ep_group_ranks(tensor: torch.Tensor) -> list:
"""
Get the expert parallel group ranks of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
list: The expert parallel group ranks of the given tensor.
"""
return tensor.moe_info.ep_group_ranks
def get_dp_group_ranks(tensor: torch.Tensor) -> list:
"""
Get the data parallel group ranks of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
list: The data parallel group ranks of the given tensor.
"""
return tensor.moe_info.dp_group_ranks
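
The tagging API above boils down to attaching a `moe_info` attribute to ordinary tensors. A minimal sketch, using a `SimpleNamespace` as a stand-in for a real `MoeParallelInfo` (building a real one requires an initialized process group):

```python
# Sketch of the moe-tensor tagging mechanism; SimpleNamespace stands in for MoeParallelInfo.
from types import SimpleNamespace
import torch

fake_info = SimpleNamespace(ep_size=4, dp_size=2)    # stand-in with illustrative sizes
w = torch.randn(8, 8)

assert not hasattr(w, "moe_info")                    # a plain tensor is not a moe tensor
w.__setattr__("moe_info", fake_info)                 # what set_moe_tensor_info() does
assert hasattr(w, "moe_info") and w.moe_info.ep_size == 4   # what is_moe_tensor()/get_ep_size() read
```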

View File

@ -0,0 +1,28 @@
from colossalai.cluster import ProcessGroupMesh
class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1):
"""
Initialize MoeParallelInfo with ep_size, dp_size and pp_size.
Args:
ep_size (int): expert parallel size
dp_size (int): data parallel (zero) size
pp_size (int, optional): pipeline parallel size. Defaults to 1.
ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
"""
self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size
if ep_inside:
self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2
self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size)
else:
self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2
self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size)
self.ep_group = self.pg.get_group_along_axis(self.ep_axis)
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
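
To visualize the two layouts above without creating real process groups, the sketch below arranges 8 ranks as a (pp, dp, ep) or (pp, ep, dp) mesh with plain reshapes. It is illustrative only and assumes the mesh places ranks in row-major order.

```python
# Illustrative only: group membership under ep_inside=True vs. False for 8 ranks,
# pp_size=1, dp_size=2, ep_size=4, assuming row-major rank placement in the mesh.
import torch

ranks = torch.arange(8)

# ep_inside=True -> axes (pp, dp, ep): ranks in the same ep group are contiguous
mesh = ranks.reshape(1, 2, 4)
print(mesh[0, 0, :].tolist())   # an ep group: [0, 1, 2, 3]
print(mesh[0, :, 0].tolist())   # a dp group:  [0, 4]

# ep_inside=False -> axes (pp, ep, dp): ranks in the same dp group are contiguous
mesh = ranks.reshape(1, 4, 2)
print(mesh[0, :, 0].tolist())   # an ep group: [0, 2, 4, 6]
print(mesh[0, 0, :].tolist())   # a dp group:  [0, 1]
```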

View File

@ -1,53 +0,0 @@
from typing import Dict, List
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import is_using_ddp
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
"""Returns a parameter dictionary, the key of which is the expert parallel
size of every parameter. Since data-parallel parameters are replicated on each
GPU, their ep_size is set to 1.
Args:
model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which the dictionary is built.
"""
epsize_param_dict = dict()
for param in model.parameters():
if not hasattr(param, "moe_info"):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = param.moe_info.ep_size
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
return epsize_param_dict
def sync_moe_model_param(model: nn.Module):
"""Make sure model parameters are consistent in MoE parallel context.
Args:
model (:class:`torch.nn.Module`): A PyTorch model whose parameters are checked for consistency.
"""
if is_using_ddp():
param_dict = get_moe_epsize_param_dict(model)
# synchronize the parameters whose dp_group is the whole world
if 1 in param_dict:
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
for param in param_dict[1]:
dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
for ep_size in param_dict:
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
for param in param_dict[ep_size]:
dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)

View File

@ -8,6 +8,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@ -18,6 +19,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
@ -75,6 +77,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@ -95,6 +98,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
# extra dp
# This group is used to sync moe params; dp_world_size = moe_duplicates * extra_dp_size.
# Non-moe params are synced by the global dp pg, while moe params are synced by the extra dp pg.
# Moe param grads are split like non-moe param grads by the global dp pg, and merged back in step().
# Moe working and master params are split by the extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@ -126,6 +139,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
self._bucket_store = BucketStore(self.dp_pg)
# moe params should not be stored in working_groups
# because they use a different parallel strategy,
# so we store them separately in param_groups
# instead of working_groups
moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
@ -133,6 +152,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
group_params = list()
for param in param_group["params"]:
if param.requires_grad:
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
continue
group_params.append(param)
# add the working params to working_param_groups for bookkeeping
@ -146,6 +170,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store them in an additional group in the optimizer
if len(moe_params) > 0:
param_group = dict()
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
@ -208,13 +241,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // self._world_size)
if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
splited_params = splited_params[self.moe_extra_dp_pg_rank]
else:
splited_params = padding_param.split(padding_param.numel() // self._world_size)
splited_params = splited_params[self._local_rank]
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
splited_param_current_rank = splited_params.detach().float().to(device)
else:
splited_param_current_rank = splited_params[self._local_rank]
splited_param_current_rank = splited_params
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
@ -247,8 +287,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
if self.moe_extra_dp_pg is None:
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
else:
# record moe and non moe param
moe_list = []
for param in self._bucket_store._param_list:
moe_list.append(is_moe_tensor(param))
# divide them into different groups
moe_grad_list = []
non_moe_grad_list = []
for grad_list in self._bucket_store._grad_in_bucket.values():
non_moe_cur_grad = []
moe_cur_grad = []
for i in range(len(grad_list)):
if moe_list[i] == True:
moe_cur_grad.append(grad_list[i])
else:
non_moe_cur_grad.append(grad_list[i])
if len(moe_cur_grad) > 0:
moe_grad_list.append(moe_cur_grad)
if len(non_moe_cur_grad) > 0:
non_moe_grad_list.append(non_moe_cur_grad)
if len(non_moe_grad_list) > 0:
non_moe_flat_grads = []
for grad_list in non_moe_grad_list:
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
non_moe_flat_grads /= self._world_size
if len(moe_grad_list) > 0:
moe_flat_grads = []
for grad_list in moe_grad_list:
moe_flat_grads.append(_flatten_dense_tensors(grad_list))
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
# ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket()
@ -256,7 +331,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._overlap_communication:
stream = self._comm_stream
# in case of the memory being reused in the default stream
flat_grads.record_stream(stream)
if self.moe_extra_dp_pg is None:
flat_grads.record_stream(stream)
else:
if len(non_moe_grad_list) > 0:
non_moe_flat_grads.record_stream(stream)
if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
stream.wait_stream(torch.cuda.current_stream())
else:
@ -265,49 +346,108 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
with torch.cuda.stream(stream):
group_id = self._bucket_store.current_group_id
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
if self.moe_extra_dp_pg is None:
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
dist.all_reduce(flat_grads, group=self.dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
if self.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=self.dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
for rank, grad_list in grad_in_bucket.items():
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if (
len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
< self._world_size
):
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
# sync extra zero group
else:
# sync non moe param in global dp group
if len(non_moe_grad_list) > 0:
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
flat_grads_per_rank = non_moe_flat_grads.split(
non_moe_flat_grads.numel() // self._world_size
)
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
# sync moe param only in zero group
if len(moe_grad_list) > 0:
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if self.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
else:
# categorize moe and non moe param
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
moe_grad_in_bucket_current_rank = []
non_moe_grad_in_bucket_current_rank = []
for idx, grad in enumerate(grad_in_bucket_current_rank):
if moe_list[idx] == True:
moe_grad_in_bucket_current_rank.append(grad)
else:
non_moe_grad_in_bucket_current_rank.append(grad)
if len(non_moe_grad_list) > 0:
flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
self._update_partitoned_grad(
non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
)
if len(moe_grad_list) > 0:
flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
param_slice = self._world_size // self.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad:
split_recieved_grad = _unflatten_dense_tensors(
split_recieved_grad, moe_grad_in_bucket_current_rank
)
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(real_grad, param_slice, group_id, param_id)
self._bucket_store.reset()
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
for rank, grad_list in enumerate(origin_grad_list):
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, self._world_size, group_id, param_id, rank)
def _update_partitoned_grad(
self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
) -> None:
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id)
def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
def _add_to_bucket(self, param, group_id):
param_size = param.numel()
@ -424,13 +564,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# else the splited grad should be attached to the splited param
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
real_working_params[group_id].append(working_param)
# moe hybrid zero
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
real_working_params[group_id].append(working_param)
if self._partition_grads:
grad = grads
else:
param_slice = self._world_size // self.moe_extra_dp_pg_size
grad = grads[
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
]
grad = flatten(grad)
else:
real_working_params[group_id].append(working_param)
grad = grads[grad_index]
# no need to copy fp32 grad if master_weights is False
grad = (
grads[grad_index].to(splited_param.dtype).to(splited_param.device)
if self._master_weights
else grads[grad_index]
)
if self._master_weights:
grad = grad.to(splited_param.dtype).to(splited_param.device)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
@ -449,24 +599,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters
self.optim.step()
# release the moe grad
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
# dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@ -488,7 +657,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@ -596,10 +764,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
else:
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@ -624,8 +798,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
else:
v_list = v.split(v.numel() // self._world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@ -656,8 +834,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
else:
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@ -688,7 +874,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
master_param.copy_(working_param.chunk(self.moe_extra_dp_pg_size)[self.moe_extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param

View File

@ -0,0 +1,129 @@
## OpenMoE
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in Jax, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates fine-tuning and inference methods.
## Usage
### 1. Installation
Please install the latest ColossalAI from source.
```bash
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```
Then install dependencies.
```bash
cd ColossalAI/examples/language/openmoe
pip install -r requirements.txt
```
Additionally, we recommend using torch 1.13.1. We've tested our code on torch 1.13.1 and found it compatible with our code and flash attention.
### 2. Install kernels (Optional)
We use `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
```bash
# install triton via pip
pip install triton
# install flash attention via pip
pip install flash-attn==2.0.5
# install apex from source
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 741bdf50825a97664db08574981962d66436d16a
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
```
### 3. Train
You can use `colossalai run` to launch single-node training:
```bash
colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
```
You can also use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```text
hostname1
hostname2
hostname3
hostname4
```
Each hostname refers to the IP address of one of your nodes. Make sure the master node can access all nodes (including itself) via passwordless SSH.
Here are the details of the CLI arguments (a sample command combining them follows the list):
- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the least memory consumption, and `hybrid` suits large-scale training.
- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
- Number of epochs: `--num_epochs`. The default value is 1.
- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
- Max length: `--max_length`. Max sequence length. Defaults to 2048.
- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format.
- Task Name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
- Learning rate: `--lr`. The default value is 1e-5.
- Weight decay: `--weight_decay`. The default value is 0.
- Zero stage: `--zero_stage`. ZeRO stage. Stage 2 is recommended for `ep` and stage 1 for `ep_zero`.
- Extra dp size: `--extra_dp_size`. Extra moe param dp size for ep_zero plugin. Recommended to be 2 or 4.
- Use kernel: `--use_kernel`. Use kernel optimizations. Flash attention and triton need to be installed to enable all kernel optimizations; they are skipped if not installed.
- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Apex needs to be installed; an error is raised if it is not.
- Router aux loss factor: `--router_aux_loss_factor`. MoE router aux loss factor. You can refer to ST-MoE for details.
- Router z loss factor: `--router_z_loss_factor`. MoE router z loss factor. You can refer to ST-MoE for details.
- Label smoothing: `--label_smoothing`. Label smoothing.
- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
- Load balance: `--load_balance`. Expert load balance. Defaults to False. Enabling it is recommended.
- Load balance interval: `--load_balance_interval`. Expert load balance interval.
- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
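For example, a single-node fine-tuning launch that combines several of the arguments above could look like this (all values are illustrative, not a recommended recipe):
```bash
colossalai run --standalone --nproc_per_node 8 train.py \
    --model_name base \
    --plugin ep_zero \
    --extra_dp_size 2 \
    --zero_stage 1 \
    --batch_size 1 \
    --lr 1e-5 \
    --precision bf16 \
    --max_length 2048 \
    --load_balance
```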
### 4. Shell Script Examples
For your convenience, we provide some shell scripts to train with various configurations. Here we show an example of how to run OpenMoE training.
#### a. Running environment
This experiment was performed on a single computing node with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected via NVLink.
#### b. Running command
We demonstrate how to run the three plugins in `train.sh`. You can choose any one of them and use your own arguments.
```bash
bash train.sh
```
#### c. Multi-Nodes Training
To run on multi-nodes, you can modify the script as:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
train.py --OTHER_CONFIGURATIONS
```
## Reference
```bibtex
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
```
```bibtex
@misc{openmoe2023,
author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
title = {OpenMoE: Open Mixture-of-Experts Language Models},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
}
```

View File

@ -0,0 +1,296 @@
import argparse
import json
import os
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
class RandomDataset(Dataset):
def __init__(
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
):
self.num_samples = num_samples
self.max_length = max_length
if os.path.exists("./mock_data.json"):
self.input_ids = []
self.attention_mask = []
with open("./mock_data.json", "r") as f:
data = json.load(f)
for v in data.values():
d = v["text"]
encode = tokenizer(
"<pad>" + d,
return_tensors="pt",
add_special_tokens=False,
max_length=max_length,
truncation=True,
padding="max_length",
)
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--seq_length",
type=int,
default=2048,
help="sequence length for the training dataloader.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
help="parallel plugin",
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
parser.add_argument("--extra_dp_size", type=int, default=1)
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
)
# bench
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
# load balance
parser.add_argument("--load_balance", action="store_true")
# overlap
parser.add_argument("--overlap_alltoall", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_kernel,
"enable_jit_fused": args.use_kernel,
"precision": "bf16",
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
zero_stage=args.zero_stage,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin}")
# Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=args.load_balance,
enable_kernel=args.use_kernel,
enable_comm_overlap=args.overlap_alltoall,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
max_length=args.seq_length,
tokenizer=tokenizer,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
# Set optimizer
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dp_size,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start training")
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter) - 1
example_data = next(train_dataloader_iter)
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
for step in pbar:
performance_evaluator.on_step_start(step)
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(example_data["input_ids"])
if (step == args.warmup // 2) and args.load_balance:
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
performance_evaluator.on_fit_end()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,78 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 8 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance
echo -e "\n\n EP-ZERO + Overlap \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall
# hybrid
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 128 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--use_kernel \
--plugin hybrid \
--pp_size 2 \
--dp_size 1 \
--ep_size 4 \
--zero_stage 1 \
--microbatch_size 32

View File

@ -0,0 +1,47 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 12 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 20 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall

View File

@ -0,0 +1,139 @@
import argparse
import functools
import os
import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
from colossalai.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def fsdp_main(rank, world_size, args):
# initialize the process group
dist.init_process_group("nccl")
MOE_MANAGER.setup(seed=42, parallel=None)
dp_size = dist.get_world_size()
dataset = RandomDataset(
max_length=args.seq_length,
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
)
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
torch.cuda.set_device(rank)
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=False,
enable_kernel=False,
enable_comm_overlap=False,
)
torch.set_default_dtype(torch.float16)
model = OpenMoeForCausalLM(config)
torch.set_default_dtype(torch.float32)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
OpenMoeDecoderLayer,
},
)
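# wrap each OpenMoeDecoderLayer as its own FSDP unit so parameters are sharded per transformer block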
model = FSDP(
model,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
model.train()
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dist.get_world_size(),
)
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
performance_evaluator.on_step_start(step)
input_ids, attention_mask, labels = (
data["input_ids"].cuda(),
data["attention_mask"].cuda(),
data["labels"].cuda(),
)
optimizer.zero_grad()
output = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
chunk_head=False,
)
loss = output["loss"]
loss.backward()
optimizer.step()
performance_evaluator.on_step_end(input_ids)
performance_evaluator.on_fit_end()
if dist.get_rank() == 0:
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="base or 8b",
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seq_length", type=int, default=2048)
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
args = parser.parse_args()
torch.manual_seed(42)
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
fsdp_main(local_rank, world_size, args)

View File

@ -0,0 +1,34 @@
#!/bin/bash
set -xue
MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE
# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
$example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE

View File

@ -0,0 +1,2 @@
host1
host2

View File

@ -0,0 +1,126 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from colossalai.logging import DistributedLogger
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
B = 1024**3
M = 1024**2
K = 1024
outputs = "Model param count: "
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
if model_param >= B:
outputs += f"{model_param / B:.2f} B\n"
elif model_param >= M:
outputs += f"{model_param / M:.2f} M\n"
elif model_param >= K:
outputs += f"{model_param / K:.2f} K\n"
else:
outputs += f"{model_param}\n"
logger.info(outputs, ranks=[0])
def get_model_numel(model: nn.Module) -> int:
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
return model_param
def divide(x: float, y: float) -> float:
if y == 0:
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None
def reset(self) -> None:
self.duration = 0.0
class PerformanceEvaluator:
"""
Callback for evaluating the performance of the model.
Args:
model_numel: The number of parameters of the model.
enable_grad_checkpoint: Whether gradient checkpointing is enabled.
ignore_steps: The number of warmup steps to ignore when measuring performance.
dp_world_size: The data parallel world size, used to scale the measured throughput.
"""
def __init__(
self,
model_numel: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
) -> None:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.dp_world_size = dp_world_size
self.world_size = dist.get_world_size()
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
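# rough decoder-only estimate: ~2 * model_numel FLOPs per token for one forward pass; x3 accounts for forward + backward (backward ~2x forward), plus one extra forward when gradient checkpointing recomputes activations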
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.world_size // self.dp_world_size
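# self.flop counts a full model replica, so divide by the model-parallel world size (tp * pp) to get a per-GPU estimate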
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
if dist.get_rank() == 0:
print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}")
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")

View File

@ -0,0 +1,57 @@
from argparse import ArgumentParser
import torch
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
return parser.parse_args()
def inference(args):
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if args.model == "test":
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=True)
model = OpenMoeForCausalLM(config)
else:
config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=False)
model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
model = model.eval().bfloat16()
model = model.to(torch.cuda.current_device())
input_str = """```
y = list(map(int, ['1', 'hello', '2']))
```
What error does this program produce?
ValueError: invalid literal for int() with base 10: 'hello'
```
sum = 0
for i in range(100):
sum += i
```
What is the value of sum immediately after the 10th time line 3 is executed?"""
# print("model config: ", model.config)
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
input_ids = input_ids.input_ids.to(torch.cuda.current_device())
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
print(f"output: \n{out}\n")
if __name__ == "__main__":
args = parse_args()
inference(args)

View File

@ -0,0 +1 @@
python infer.py --model "base"

View File

@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2022 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert T5X checkpoint to PyTorch
Steps:
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
- Convert:
```
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
--pytorch_dump_path=$HOME/t5_1_1_small_pt
```
"""
import argparse
import collections
import torch
from flax import traverse_util
from modeling_openmoe import OpenMoeForCausalLM
from t5x import checkpoints
from transformers import LlamaConfig
from transformers.utils import logging
logging.set_verbosity_info()
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
"""Returns the KOQV parameters of (self-)attention. Does not transpose."""
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
return k, o, q, v
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the MLP parameters of a layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
return wi, wo
def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the extra-MLP parameters of a MoE layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
return wi, wo
def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the expert MLP parameters of a MoE layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
return wi, wo
def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the router (gate) weights of a MoE layer. Does not transpose."""
return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
"""Returns the layer norm param of a layer."""
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
# v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
split_mlp_wi = True
print("Split MLP:", split_mlp_wi)
new = collections.OrderedDict()
print(old.keys())
for key, value in old.items():
print(f"{key}: {value.shape}")
# Shared embeddings.
new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
# Decoder.
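# Note: Flax/T5X linear kernels are stored as (in_features, out_features) while PyTorch nn.Linear weights are (out_features, in_features), hence the .T transposes below.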
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
if (i + 1) % moe_interval == 0:
# moe
gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
new[f"model.layers.{i}.mlp.experts.wo"] = wo
# extra
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
else:
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new
def make_state_dict(converted_params):
"""Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
"""Replaces the params in the model with the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables,
num_layers=config.num_hidden_layers,
moe_interval=config.moe_layer_interval)
state_dict = make_state_dict(converted)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model
config = LlamaConfig.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
model = OpenMoeForCausalLM(config)
# Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
# Verify that we can load the checkpoint.
model.from_pretrained(pytorch_dump_path)
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
# Required parameters
parser.add_argument("--t5x_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the T5X checkpoint.")
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
)
parser.add_argument("--pytorch_dump_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)

View File

@ -0,0 +1 @@
python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save

File diff suppressed because it is too large

View File

@ -0,0 +1,24 @@
{
"architectures": [
"OpenMoeForCausalLM"
],
"intermediate_size": 8192,
"hidden_size": 2048,
"num_hidden_layers": 24,
"head_dim": 128,
"num_attention_heads": 24,
"dropout_rate": 0.0,
"layer_norm_epsilon": 1e-06,
"vocab_size": 256384,
"hidden_act": "swiglu",
"num_experts": 32,
"topk": 2,
"capacity_factor_train": 1.25,
"capacity_factor_eval": 2.0,
"min_capacity": 4,
"noisy_policy": null,
"drop_tks": true,
"expert_parallel": null,
"gated": true,
"moe_layer_interval": 6
}

View File

@ -0,0 +1,24 @@
{
"architectures": [
"OpenMoeForCausalLM"
],
"intermediate_size": 2048,
"hidden_size": 768,
"num_hidden_layers": 12,
"head_dim": 64,
"num_attention_heads": 12,
"dropout_rate": 0.0,
"layer_norm_epsilon": 1e-06,
"vocab_size": 256384,
"hidden_act": "swiglu",
"num_experts": 16,
"topk": 2,
"capacity_factor_train": 1.25,
"capacity_factor_eval": 2.0,
"min_capacity": 4,
"noisy_policy": null,
"drop_tks": true,
"expert_parallel": null,
"gated": true,
"moe_layer_interval": 4
}

View File

@ -0,0 +1,562 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from colossalai.moe.manager import MOE_MANAGER
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
class OpenMoePolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="pre_extra_mlp_layernorm",
target_module=FusedRMSNorm,
ignore_if_not_exist=True,
),
],
policy=policy,
target_key=OpenMoeDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=OpenMoeModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in openmoe.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "OpenMoeModel":
module = self.model
else:
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "OpenMoeModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
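# Note on distribute_layers below: the hand-tuned splits give the last stage fewer decoder layers, presumably to offset the final norm and large-vocab LM head it also holds (see get_held_layers above); unhandled sizes fall back to the default policy.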
@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages
"""
if num_layers == 24 and num_stages == 4:
return [7, 7, 7, 3]
elif num_layers == 24 and num_stages == 2:
return [15, 9]
elif num_layers == 12 and num_stages == 4:
return [5, 5, 5, 1]
elif num_layers == 12 and num_stages == 2:
return [8, 4]
else:
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, falling back to the default pipeline policy")
return Policy.distribute_layers(num_layers, num_stages)
class OpenMoeModelPolicy(OpenMoePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=OpenMoeModel,
new_forward=OpenMoePipelineForwards.openmoe_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
OpenMoeForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
])
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=OpenMoeForCausalLM,
new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1):
# tie weights
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}]
return []
class OpenMoePipelineForwards:
"""
This class serves as a micro library for forward function substitution of OpenMoe models
under the pipeline parallel setting.
"""
@staticmethod
def openmoe_model_forward(
self: OpenMoeModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
past_router_aux_loss: Optional[torch.FloatTensor] = None,
past_router_z_loss: Optional[torch.FloatTensor] = None,
):
# reset moe loss for different data
MOE_MANAGER.reset_loss()
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=hidden_states.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
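# each pipeline stage only runs its assigned slice of decoder layers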
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# concat past losses with current ones
router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
if past_router_aux_loss is not None and past_router_z_loss is not None:
router_aux_loss = past_router_aux_loss + router_aux_loss
router_z_loss = past_router_z_loss + router_z_loss
if stage_manager.is_last_stage():
return tuple([
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
])
# always return a dict for intermediate stages
return {
"hidden_states": hidden_states,
"router_aux_loss": router_aux_loss,
"router_z_loss": router_z_loss,
}
@staticmethod
def llama_for_causal_lm_forward(
self: OpenMoeForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
chunk_head: Optional[bool] = True,
past_router_aux_loss: Optional[torch.FloatTensor] = None,
past_router_z_loss: Optional[torch.FloatTensor] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = OpenMoePipelineForwards.openmoe_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
past_router_aux_loss=past_router_aux_loss,
past_router_z_loss=past_router_z_loss,
)
if stage_manager.is_last_stage():
(
hidden_states,
past_key_values,
all_hidden_states,
attentions,
router_aux_loss,
router_z_loss,
) = outputs
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
loss = None
# if no training, just do forward
if labels is None:
logits = self.lm_head(hidden_states)
logits = logits.float()
# the vocab size for OpenMoE is over 300k,
# which causes huge activation memory in training, up to 20 GB for a single sequence,
# so we chunk the LM head and use checkpointing to reduce memory
else:
if chunk_head:
def create_custom_forward(module):
def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous().float()
shift_labels = inputs[1][..., 1:].contiguous()
# Flatten the tokens
loss = self._calculate_loss(shift_logits, shift_labels)
return loss
return custom_forward
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
loss = aux_loss + z_loss
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
)
logits = None
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
loss = aux_loss + z_loss
loss = loss + self._calculate_loss(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=attentions,
)
else:
hidden_states = outputs["hidden_states"]
router_aux_loss = outputs["router_aux_loss"]
router_z_loss = outputs["router_z_loss"]
return {
"hidden_states": hidden_states,
"past_router_aux_loss": router_aux_loss,
"past_router_z_loss": router_z_loss,
}

View File

@ -0,0 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers >= 4.20.0
sentencepiece
datasets

View File

@ -0,0 +1,37 @@
pip install -r requirements.txt
# inference
python infer.py --model "test"
# train
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep" \
--batch_size 1
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 1 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 2 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--model_name "test" \
--plugin "hybrid" \
--num_epoch 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 2 \
--zero_stage 1 \
--batch_size 1

View File

@ -0,0 +1,377 @@
import argparse
import os
from functools import partial
from typing import Dict
import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
data = tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
add_special_tokens=False,
)
data = {k: v.cuda() for k, v in data.items()}
data["labels"] = data["input_ids"].clone()
return data
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b", "test"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["ep", "ep_zero", "hybrid"],
help="Parallel method. ep_zero is recommended for general cases, ep provides the least memory consumption, and hybrid suits large-scale training.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help=" The interval (steps) of saving checkpoints.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--dataset",
type=str,
default="yizhongw/self_instruct",
help="dataset name from `datasets` repo.",
)
parser.add_argument(
"--task_name",
type=str,
default="super_natural_instructions",
help="task of corresponding dataset.",
)
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
# ep_zero plugin
parser.add_argument(
"--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
# loss
parser.add_argument(
"--router_aux_loss_factor",
type=float,
default=0.01,
help="MoE router aux loss factor. You can refer to STMoE for details.",
)
parser.add_argument(
"--router_z_loss_factor",
type=float,
default=0.0001,
help="MoE router z loss factor. You can refer to STMoE for details.",
)
parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
parser.add_argument(
"--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
)
# load balance
parser.add_argument(
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
)
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
# communicate overlap
parser.add_argument(
"--comm_overlap",
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build OpenMoe model
if test_mode:
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
config.hidden_size = 128
config.intermediate_size = 256
config.vocab_size = 32000
else:
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
router_aux_loss_factor=args.router_aux_loss_factor,
router_z_loss_factor=args.router_z_loss_factor,
z_loss_factor=args.z_loss_factor,
enable_load_balance=args.load_balance,
enable_comm_overlap=args.comm_overlap,
enable_kernel=args.use_kernel,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if test_mode:
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
collate_fn = None
else:
dataset = load_dataset(args.dataset, args.task_name)
dataset = dataset["train"]
collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
if not test_mode:
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start finetuning")
for epoch in range(args.num_epoch):
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter)
with tqdm(
range(total_len),
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
disable=not coordinator.is_master(),
) as pbar:
for step in pbar:
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
# Apply load balance
if (
args.load_balance
and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0
):
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
# save checkpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
booster.save_model(model, args.output_path, shard=True)
# save a checkpoint at the end of each epoch
booster.save_model(model, args.output_path, shard=True)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training
coordinator.print_on_master(f"Finish training")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,40 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# ep zero
torchrun --standalone --nproc_per_node $NUM_GPU train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "ep_zero" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--extra_dp_size 2
# ep
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "ep_zero" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "hybrid" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \

View File

@ -2,4 +2,4 @@
markers =
dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy

View File

@ -0,0 +1,56 @@
import pytest
import torch
from packaging import version
from torch import nn
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
BATCH_SIZE = 4
SEQ_LEN = 16
HIDDEN_SIZE = 32
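# eager SwiGLU reference implementation used to check the fused Triton kernel below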
def SwiGLU(x):
"""SwiGLU activation (reference implementation).
Args:
x: input tensor; its last dimension is split in half into the gate and value parts
"""
size = x.shape[-1]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = torch.split(x, size // 2, -1)
return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_llama_act_combine(dtype: str):
x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
x_gate_torch = nn.Parameter(x_gate.detach().clone())
x_gate_kernel = nn.Parameter(x_gate.detach().clone())
x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
x_up_torch = nn.Parameter(x_up.detach().clone())
x_up_kernel = nn.Parameter(x_up.detach().clone())
torch_out = SwiGLU(x_gate_torch) * x_up_torch
kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
atol = 1e-5 if dtype == torch.float32 else 5e-2
assert torch.allclose(torch_out, kernel_out, atol=atol)
torch_out.mean().backward()
kernel_out.mean().backward()
assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)
if __name__ == '__main__':
test_llama_act_combine(torch.float16)

169
tests/test_moe/moe_utils.py Normal file
View File

@ -0,0 +1,169 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
class MoeModel(nn.Module):
def __init__(self, enable_load_balance: bool = False):
class TestSubModule(nn.Module):
def __init__(self):
super().__init__()
self.moe = SparseMLP(
num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
)
self.proj = nn.Linear(16, 4)
def forward(self, x):
x = self.moe(x)
x = self.proj(x)
return x
super().__init__()
self.test_embed = nn.Linear(4, 16)
self.test_transform = TestSubModule()
def forward(self, x):
MOE_MANAGER.reset_loss()
x = self.test_embed(x)
x = self.test_transform(x)
return x
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in the data parallel group and
the MoE model parallel group. An all-reduce collective communication is performed in
:func:`handle_gradient` within each data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer=None):
super().__init__(model, optimizer)
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe model parallel group
"""
if dist.get_world_size() > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
if 1 in epsize_param_dict:
bucket_allreduce(param_list=epsize_param_dict[1])
for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
bucket_allreduce(
param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
)
def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
tp_model (MoeModule)
ep_model (MoeModule)
"""
for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
assert tp_name == ep_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, ep_param)
assert torch.allclose(tp_param.grad, ep_param.grad)
else:
tp_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
# get tp param
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2]
tp_rank = get_ep_rank(tp_param)
tp_dim = tp_dim[0] + 1
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
new_tp_param = all_param[tuple(tp_slice)]
if assert_grad_flag:
new_grad = all_grad[tuple(tp_slice)]
if assert_grad_flag:
assert torch.allclose(tp_param, new_tp_param)
assert torch.allclose(tp_param.grad, new_grad)
else:
tp_param.data.copy_(new_tp_param.data)
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param)
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def assert_not_equal_in_group(tensor, process_group=None):
# all gather tensors from different ranks
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=process_group)
# check if they are equal one by one
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(
a, b
), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"


@@ -4,40 +4,58 @@ import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group
BATCH_SIZE = 4
DIM = 16
CONFIG = dict()
def run_test(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
expert_module = nn.Linear
expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
MOE_CONTEXT.setup(42) # MOE initialization
noisy_func = UniformNoiseGenerator()
router = Top1Router(noisy_func=noisy_func)
MOE_MANAGER.setup(42, parallel="EP") # MOE initialization
num_experts_list = [1, 2, 4]
layer_list = []
for num_experts in num_experts_list:
exp = Experts(expert_module, num_experts, **expert_factor)
moe_layer = MoeLayer(DIM, num_experts, router, exp)
moe_layer = SparseMLP(
hidden_size=DIM,
intermediate_size=DIM * 4,
num_experts=num_experts,
router_top_k=1,
router_noisy_policy="Jitter",
)
layer_list.append(moe_layer)
model = nn.ModuleList(layer_list)
model = model.to(get_current_device())
dist_dict = MOE_MANAGER.parallel_info_dict
assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
sync_moe_model_param(model)
dist_dict = MOE_CONTEXT.parallel_info_dict
assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
# MoE model synchronization passed
grad_handler = MoeGradientHandler(model, 0)
@@ -47,17 +65,18 @@ def run_test(rank, world_size, port):
data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
grad = torch.randn_like(data)
MOE_CONTEXT.reset_loss()
MOE_MANAGER.reset_loss()
for layer in layer_list:
data, _ = layer(data)
data = layer(data)
data.backward(grad)
grad_handler.handle_gradient()
assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group)
assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group)
assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group)
# MoE grad handler test passed


@@ -1,49 +1,47 @@
import pytest
import torch
import torch.nn as nn
import torch.distributed as dist
import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
BATCH_SIZE = 16
BATCH_SIZE = 4
NUM_EXPERTS = 4
CONFIG = dict()
def check_equal(tensor_a, tensor_b, atol=1e-06):
assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router):
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
local_rank = dist.get_rank()
MOE_CONTEXT.setup(42) # MOE environment initialization
MOE_CONTEXT.reset_loss()
torch.manual_seed(rs + local_rank) # set each process has different random seed
MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization
MOE_MANAGER.reset_loss()
    torch.manual_seed(rs + local_rank)  # give each process a different random seed
# get randomized data
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
expert_module = nn.Linear
expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
layer = SparseMLP(hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0)
layer = layer.to(get_current_device())
if data_type == torch.float16:
layer = layer.half()
# use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
layer.use_kernel = False
old_out, _ = layer(tokens)
layer.enable_kernel = False
old_out = layer(tokens)
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad) # get gradient
@@ -56,8 +54,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
tokens.grad.zero_()
layer.gate_weight.grad.zero_()
layer.use_kernel = True
new_out, _ = layer(tokens) # get outputs through colossal kernel
layer.enable_kernel = True
new_out = layer(tokens) # get outputs through colossal kernel
if data_type == torch.float32:
check_equal(old_out, new_out)
@@ -86,11 +84,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
@pytest.mark.parametrize("topk", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router):
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router)
def test_moe_kernel(rs, hidden_size, data_type, topk):
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
if __name__ == "__main__":
test_moe_kernel(2, 256, torch.float16, Top2Router)
if __name__ == '__main__':
test_moe_kernel(2, 256, torch.float16, 2)
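
The old Experts/MoeLayer/router stack is replaced here by a single SparseMLP whose routing is configured through keyword arguments and whose fused dispatch/combine kernel is toggled with enable_kernel. A condensed construction sketch under the same assumptions as run_routing (launched process, MOE_MANAGER already set up, imports from this file):

layer = SparseMLP(
    hidden_size=128,
    intermediate_size=256,
    num_experts=4,
    router_top_k=2,                     # replaces the old Top2Router
    router_capacity_factor_train=1.0,
).to(get_current_device())
layer.enable_kernel = True              # use the fused kernel for dispatch/combine
out = layer(torch.randn(16, 128, device=get_current_device()))   # single tensor output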


@@ -1,50 +1,138 @@
import importlib
import os
import shutil
import sys
import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_zero.test_legacy.common import CONFIG
sys.path.append(os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"examples/language/openmoe",
))
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
def exam_moe_checkpoint():
with ColoInitContext(device=get_current_device()):
model = MoeModel(checkpoint=True)
save_moe_model(model, "temp_path.pth")
def get_config():
config = LlamaConfig(
vocab_size=300,
hidden_size=16,
intermediate_size=32,
num_hidden_layers=4,
num_attention_heads=2,
head_dim=4,
dropout_rate=0.0,
hidden_act="swiglu",
)
set_openmoe_args(config, num_experts=16, moe_layer_interval=1)
return config
with ColoInitContext(device=get_current_device()):
other_model = MoeModel(checkpoint=True)
load_moe_model(other_model, "temp_path.pth")
state_0 = model.state_dict()
state_1 = other_model.state_dict()
for k, v in state_0.items():
u = state_1.get(k)
def get_model(parallel):
config = get_config()
model = OpenMoeForCausalLM(config)
    if parallel is None:
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
zero_stage=0,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "zero_ep":
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
)
booster = Booster(plugin=plugin)
model, _, _, _, _ = booster.boost(model=model)
return model, booster
def _test_moe_checkpoint(parallel, shard):
    if parallel is None:
MOE_MANAGER.setup(
seed=42,
parallel=None,
)
elif parallel == "zero2_ep":
MOE_MANAGER.setup(
seed=42,
parallel="EP",
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
seed=42,
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1 = get_model(parallel)
model2, booster2 = get_model(parallel)
if shard:
booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt")
else:
booster1.save_model(model1, "tmp_ckpt.pth")
booster2.load_model(model2, "tmp_ckpt.pth")
state1 = model1.state_dict()
state2 = model2.state_dict()
for k, v in state1.items():
u = state2.get(k)
assert torch.equal(u.data, v.data)
if dist.get_rank() == 0:
os.remove("temp_path.pth")
if shard:
shutil.rmtree("./tmp_ckpt")
else:
os.remove("tmp_ckpt.pth")
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_CONTEXT.setup(seed=42)
exam_moe_checkpoint()
def _run_dist(rank, world_size, port, parallel, shard):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
_test_moe_checkpoint(parallel, shard)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"])
@pytest.mark.parametrize("shard", [True, False])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size):
spawn(_run_dist)
def test_moe_checkpoint(world_size, parallel, shard):
spawn(_run_dist, world_size, parallel=parallel, shard=shard)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4)
test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True)


@@ -1,55 +0,0 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_zero.test_legacy.common import CONFIG
@parameterize("init_device_type", ["cpu", "cuda"])
def exam_moe_colo_init(init_device_type):
world_size = dist.get_world_size()
if init_device_type == "cuda":
init_device = get_current_device()
elif init_device_type == "cpu":
init_device = torch.device("cpu")
else:
raise NotImplementedError("Unknown device found.")
with ColoInitContext(device=init_device):
model = MoeModel(checkpoint=True)
for name, param in model.named_parameters():
assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)
if hasattr(param, "moe_info"):
param.set_process_group(param.moe_info.pg)
if hasattr(param, "moe_info"):
assert param.process_group.dp_world_size() == param.moe_info.dp_size
else:
assert param.process_group.dp_world_size() == world_size
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_CONTEXT.setup(seed=42)
exam_moe_colo_init()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_colo_init(world_size):
spawn(_run_dist, world_size)
if __name__ == "__main__":
test_moe_colo_init(world_size=4)


@@ -0,0 +1,81 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
assert batch_size % world_size == 0
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel=None)
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel="EP")
ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
ep_model = ep_model.to(get_current_device())
tp_model = tp_model.to(get_current_device())
local_model = local_model.to(get_current_device())
# sync ep param
sync_moe_model_param(ep_model)
dist_dict = MOE_MANAGER.parallel_info_dict
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
grad_handler = MoeGradientHandler(ep_model)
# sync tp param
sync_tp_from_ep(tp_model, ep_model)
# sync local param
sync_local_from_ep(local_model, ep_model)
rank = dist.get_rank()
torch.cuda.manual_seed(seed)
tp_data = torch.randn(batch_size, dim, device=get_current_device())
micro_batch_size = batch_size // world_size
ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
out_local = local_model(tp_data)
MOE_MANAGER.reset_loss()
out_tp = tp_model(tp_data)
MOE_MANAGER.reset_loss()
out_ep = ep_model(ep_data)
MOE_MANAGER.reset_loss()
assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
out_local.mean().backward()
out_tp.mean().backward()
out_ep.mean().backward()
grad_handler.handle_gradient()
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("dim", [32])
@pytest.mark.parametrize("seed", [42])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
if __name__ == '__main__':
test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)
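
The test above builds the same SparseMLP under three MOE_MANAGER configurations; re-initializing the global manager between constructions is what decides how expert weights are partitioned. The pattern in isolation, assuming a launched distributed context and the imports from this file:

MOE_MANAGER.__init__()                        # reset the global manager
MOE_MANAGER.setup(seed=42, parallel=None)     # no expert parallelism: every rank holds all experts
local_model = SparseMLP(num_experts=8, hidden_size=32, intermediate_size=64)

MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel="EP")     # expert parallel: experts split across ranks
ep_model = SparseMLP(num_experts=8, hidden_size=32, intermediate_size=64)

MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel="TP")     # tensor parallel: expert weights sharded along one dim
tp_model = SparseMLP(num_experts=8, hidden_size=32, intermediate_size=64)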


@@ -3,66 +3,80 @@ import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe import Experts
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
D_MODEL = 4
D_FF = 8
CONFIG = dict()
HIDDEN_SIZE = 4
INTERMEDIATE_SIZE = 8
def run_test(rank, world_size, port):
world_size = 4
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
expert_module = nn.Linear
expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
def run_moe_init(expert_parallel):
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=expert_parallel)
expert_args = dict(
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
expert_parallel=expert_parallel,
)
exp0 = MLPExperts(1, **expert_args)
exp1 = MLPExperts(2, **expert_args)
exp2 = MLPExperts(4, **expert_args)
MOE_CONTEXT.setup(42) # MOE environment initialization
exp0 = Experts(expert_module, 1, **expert_factor)
exp1 = Experts(expert_module, 2, **expert_factor)
exp2 = Experts(expert_module, 4, **expert_factor)
exp3 = Experts(expert_module, 8, **expert_factor)
if expert_parallel == "EP":
assert exp0.num_local_experts == 1
assert exp1.num_local_experts == 1
assert exp2.num_local_experts == 2
else:
assert exp0.num_local_experts == 1
assert exp1.num_local_experts == 2
assert exp2.num_local_experts == 4
assert exp0.num_local_experts == 1
assert exp1.num_local_experts == 1
assert exp2.num_local_experts == 1
assert exp3.num_local_experts == 2
# experts deployment passed
parallel_info_dict = MOE_CONTEXT.parallel_info_dict
parallel_info_dict = MOE_MANAGER.parallel_info_dict
rank = dist.get_rank()
assert len(parallel_info_dict) == 3
assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
# group creation assert
assert len(parallel_info_dict) == 2
assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
assert dist.get_rank(parallel_info_dict[1].ep_group) == 0
assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
# group creation passed
model = nn.ModuleList([exp0, exp1, exp2, exp3])
model = nn.ModuleList([exp0, exp1, exp2])
model = model.to(get_current_device())
sync_moe_model_param(model)
assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group)
assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group)
# MOE experts layout success when ep_size = 1
assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group)
assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group)
assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group)
assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group)
# MOE experts layout success when ep_size = 2
assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group)
assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group)
def _run_test(rank, world_size, port, expert_parallel):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_moe_init(expert_parallel)
@pytest.mark.dist
@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
@rerun_if_address_is_in_use()
def test_moe_initialization():
spawn(run_test, 4)
def test_moe_initialization(expert_parallel):
spawn(_run_test, 2, expert_parallel=expert_parallel)
if __name__ == "__main__":
test_moe_initialization()
test_moe_initialization("EP")
test_moe_initialization("TP")


@@ -0,0 +1,97 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeModel
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss / 2)
else:
loss.backward()
return y
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=None)
torch_model = MoeModel()
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
zero_model = MoeModel()
extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
if is_moe_tensor(zero_param):
num_expert = torch_param.data.shape[0]
zero_param.data.copy_(
torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
.detach()
.clone()
)
else:
zero_param.data.copy_(torch_param.data.detach().clone())
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
run_fwd_bwd(torch_model, data, label, criterion, None)
torch_optimizer.step()
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
zero_optimizer.step()
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.named_parameters()
):
if is_moe_tensor(zero_param):
num_expert = torch_param.data.shape[0]
torch_param.data = torch_param.data[
ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)
]
assert torch.allclose(
torch_param.data, zero_param.data, atol=1e-4
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_optim(world_size=4)
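
The key wiring in this test is forwarding the MoE-specific data-parallel group to the ZeRO optimizer through the plugin. In isolation, under the same max_ep_size=2 EP setup as above, the wiring looks roughly like this:

extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
plugin = LowLevelZeroPlugin(stage=2, precision="fp32")
plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group   # hand the MoE dp group to ZeRO
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)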


@@ -0,0 +1,190 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(
seed=42,
parallel="EP",
)
zero_model = MoeModel(enable_load_balance=True)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True)
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel="EP")
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda().bfloat16()
grad_handler = MoeGradientHandler(torch_model)
# run to update expert load
data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
# run torch model twice
run_fwd_bwd(torch_model, data, label, criterion, None)
grad_handler.handle_gradient()
torch_optimizer.step()
torch_optimizer.zero_grad()
run_fwd_bwd(torch_model, data, label, criterion, None)
grad_handler.handle_gradient()
# get optim and load status in zero model
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
zero_optimizer.step()
zero_optimizer.zero_grad()
with torch.no_grad():
origin_out = zero_model(data)
# load balance
apply_load_balance(zero_model, zero_optimizer)
# run again to test
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
torch.allclose(origin_out, zero_out)
# assert optim
torch_optimizer.step()
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_optimizer.step()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}"
def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=None)
torch_model = MoeModel()
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(
seed=42,
max_ep_size=2,
use_ep_inside=False,
parallel="EP",
)
zero_model = MoeModel(enable_load_balance=True)
extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
if is_moe_tensor(zero_param):
num_expert = torch_param.data.shape[0]
zero_param.data.copy_(
torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
.detach()
.clone()
)
else:
zero_param.data.copy_(torch_param.data.detach().clone())
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
    # run the torch model twice
run_fwd_bwd(torch_model, data, label, criterion, None)
torch_optimizer.step()
torch_optimizer.zero_grad()
run_fwd_bwd(torch_model, data, label, criterion, None)
torch_optimizer.step()
# run zero
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
zero_optimizer.step()
zero_optimizer.zero_grad()
with torch.no_grad():
origin_out = zero_model(data)
# load balance
apply_load_balance(zero_model, zero_optimizer)
# assert out
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
torch.allclose(origin_out, zero_out)
# assert optim
zero_optimizer.step()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
# TODO: high atol, check if bug exists
assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}"
def run_dist(rank, world_size, port):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
run_hybrid_zero_optim_test(rank, world_size, stage=1)
run_hybrid_zero_optim_test(rank, world_size, stage=2)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_load_balance(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_load_balance(world_size=4)


@@ -0,0 +1,41 @@
import pytest
import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(["router", "num_groups"], [
(Top1Router(), 1),
(Top2Router(), 1),
(TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
(4, 5, 8),
(3, 4, 4),
])
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
x = x.expand(num_groups, -1, -1)
router.train()
if isinstance(router, TopKRouter):
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
combine_array, dispatch_mask = router(x)
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
combine_array, dispatch_mask = router(x)
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
if __name__ == "__main__":
test_router_forward(Top1Router(), 4, 4, 4, 1)
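
As the assertions above encode, every router forward returns a (combine_array, dispatch_mask) pair whose leading dimensions mirror the input, with one extra trailing dimension (presumably the expert capacity), and per-token dispatch counts bounded by the router's k_value. A brief sketch with hypothetical sizes, using the imports above (TopKRouter additionally takes expert_capacity):

router = Top2Router()
x = torch.randn(20, 8).cuda()                   # (tokens, num_experts) router logits
combine_array, dispatch_mask = router(x)
assert combine_array.shape[:-1] == x.shape      # one extra trailing dimension
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)   # at most k slots per token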


@@ -0,0 +1,105 @@
import pytest
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel()
optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
# assert zero model
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.module.named_parameters()
):
assert zero_name == torch_name
assert torch.allclose(zero_param.data, torch_param.data)
data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda()
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
assert torch.allclose(torch_out, zero_out)
grad_handler.handle_gradient()
for (zero_name, zero_param), (torch_name, torch_param) in zip(
zero_model.module.named_parameters(), torch_model.named_parameters()
):
assert zero_name == torch_name
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(zero_param, "moe_info"):
assert len(zero_grad_list) == 0
assert torch.allclose(zero_param.grad, torch_param.grad)
else:
assert len(zero_grad_list) > 0
torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
if stage == 2:
torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
assert len(zero_grad_list) == len(torch_grad_list)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
assert torch.allclose(zero_grad, torch_grad)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(seed=42, parallel="EP")
seed_all(42 + rank)
run_zero_test(rank, world_size, stage=1)
run_zero_test(rank, world_size, stage=2)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_model(world_size=2)


@@ -1,106 +0,0 @@
import pytest
import torch
import torch.nn as nn
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.nn import CheckpointModule
from colossalai.nn.layer import MoeModule
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from tests.test_zero.test_legacy.common import CONFIG
class MoeModel(nn.Module):
def __init__(self, checkpoint: bool = False):
class TestSubModule(CheckpointModule):
def __init__(self):
super().__init__(checkpoint)
expert_cls = nn.Linear
expert_args_dict = dict(in_features=16, out_features=16)
self.moe = MoeModule(
dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict
)
self.proj = nn.Linear(16, 4)
def _forward(self, x):
x, y = self.moe(x)
x = self.proj(x)
return x, y
super().__init__()
self.test_embed = nn.Linear(4, 16)
self.test_transform = TestSubModule()
def forward(self, x):
MOE_CONTEXT.reset_loss()
x = self.test_embed(x)
x, y = self.test_transform(x)
MOE_CONTEXT.add_loss(y)
return x
@parameterize("init_device_type", ["cpu", "cuda"])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_moe_zero_init(init_device_type, shard_strategy_class):
get_dist_logger("test_moe_zero_init")
if init_device_type == "cuda":
init_device = get_current_device()
elif init_device_type == "cpu":
init_device = torch.device("cpu")
else:
raise NotImplementedError("Unknown device found.")
model_numel_tensor = torch.zeros(1, dtype=torch.int)
with ZeroInitContext(
target_device=init_device,
shard_strategy=shard_strategy_class(),
shard_param=True,
model_numel_tensor=model_numel_tensor,
):
model = MoeModel(checkpoint=True)
for name, param in model.named_parameters():
assert hasattr(param, "colo_attr")
# the parameters in moe experts and its gate should not be sharded
if ("experts" in name) or ("gate" in name) or ("residual_combine" in name):
assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
else:
assert param.colo_attr.sharded_data_tensor.is_sharded
# the parameters in moe experts is not replicated
if "experts" in name:
assert not param.colo_attr.is_replicated
else:
assert param.colo_attr.is_replicated
if param.colo_attr.param_is_sharded:
assert (
param.colo_attr.data_payload.device.type == init_device.type
), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}"
else:
assert param.colo_attr.data_payload.device.type == "cuda"
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_CONTEXT.setup(seed=42)
run_moe_zero_init()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_zero_init(world_size):
spawn(_run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_init(world_size=2)


@@ -1,70 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd
@parameterize("enable_autocast", [False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
shard_strategy = shard_strategy_class()
get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
_, train_dataloader, _, optimizer_class, _ = get_components_func()
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
with ZeroInitContext(
target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True
):
zero_model = MoeModel(checkpoint=True)
zero_model = ShardedModelV2(zero_model, shard_strategy)
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.data_payload)
model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda()
grad_handler = MoeGradientHandler(model)
for i, (data, label) in enumerate(train_dataloader):
if i > 5:
break
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, enable_autocast)
run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
grad_handler.handle_gradient()
check_grads_padding(model, zero_model, loose=True)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_CONTEXT.setup(seed=42)
run_model_test()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_model(world_size=2)


@@ -2,120 +2,91 @@ import pytest
import torch
import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
def _run_step(model, optimizer, data, label, criterion, grad_handler):
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
if grad_handler is not None:
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel()
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
for _ in range(2):
data = torch.randn(16, 4).cuda() / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
run_fwd_bwd(torch_model, data, label, criterion, None)
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()
optimizer.step()
torch_optimizer.step()
zero_optimizer.step()
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.named_parameters()
):
assert torch.allclose(
torch_param.data, zero_param.data
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
torch_optimizer.zero_grad()
zero_optimizer.zero_grad()
@parameterize("cpu_offload", [True])
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
@parameterize("reuse_fp16_shard", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(
cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0
):
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
MOE_CONTEXT.reset_loss()
get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
_, train_dataloader, _, optimizer_class, _ = get_components_func()
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
with ZeroInitContext(
target_device=torch.device("cpu") if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True,
):
zero_model = MoeModel(checkpoint=True)
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
tensor_placement_policy="cpu" if cpu_offload else "cuda",
reuse_fp16_shard=reuse_fp16_shard,
)
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda().float()
if use_cpuadam:
optimizer_class = CPUAdam
optim = optimizer_class(model.parameters(), lr=1e-3)
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(
zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio
)
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
apex_grad_handler = MoeGradientHandler(model)
for i, (data, label) in enumerate(train_dataloader):
if i > 5:
break
data, label = data.cuda(), label.cuda()
_run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
_run_step(zero_model, sharded_optim, data, label, criterion, None)
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
for param in model.parameters():
assert not has_inf_or_nan(param)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(seed=42, parallel="EP")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_CONTEXT.setup(seed=42)
_run_test_sharded_optim_v2()
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
spawn(_run_dist, world_size)
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_optim(world_size=4)
test_moe_zero_optim(world_size=2)