mirror of https://github.com/hpcaitech/ColossalAI
parent 8993c8a817 · commit dc003c304c

@@ -0,0 +1,382 @@
import random
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.plugin.hybrid_parallel_plugin import (
    HybridParallelAMPOptimizer,
    HybridParallelModule,
    HybridParallelNaiveOptimizer,
    HybridParallelPlugin,
    get_param_info,
    init_pipeline_optimizer,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoeCheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer

PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2

class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
    def __init__(
        self,
        optimizer: Optimizer,
        model: Module,
        use_pipeline: bool,
        param_info: OrderedDict,
        initial_scale: int = 2**16,  # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,  # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,  # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
        pp_process_group: Optional[ProcessGroup] = None,
        forced_dtype: Optional[torch.dtype] = None,
        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
    ):
        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.dp_pg = dp_process_group
        self.tp_pg = tp_process_group
        self.pp_pg = pp_process_group
        if use_pipeline:
            init_pipeline_optimizer(optimizer, model)
        super().__init__(
            optimizer=optimizer,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            clip_grad_norm=clip_grad_norm,
            verbose=verbose,
            reduce_bucket_size=reduce_bucket_size,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            partition_grad=partition_grad,
            cpu_offload=cpu_offload,
            dp_process_group=dp_process_group,
            forced_dtype=forced_dtype,
            moe_extra_dp_process_group=moe_extra_dp_process_group,
        )

class MoeHybridParallelPlugin(HybridParallelPlugin):
    """
    Plugin for MoE hybrid parallel training.
    Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
    The sizes of tp and pp should be passed in by the user; the size of dp is then calculated automatically as dp_size = world_size / (tp_size * pp_size).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import MoeHybridParallelPlugin

        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2)

        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
        precision (str, optional): Specifies the precision of parameters during training.
            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
            Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
            When set to 0, ZeRO will not be used. Defaults to 0.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
            Currently the optimization methods include fused normalization, flash attention and JIT.
            Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
        broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
    """

    def __init__(
        self,
        tp_size: int,
        pp_size: int,
        extra_dp_size: int = 1,
        precision: str = "fp16",
        zero_stage: int = 0,
        enable_all_optimization: bool = False,
        enable_fused_normalization: bool = False,
        enable_flash_attention: bool = False,
        enable_jit_fused: bool = False,
        enable_sequence_parallelism: bool = False,
        enable_sequence_overlap: bool = False,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
        broadcast_buffers: bool = True,
        ddp_bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        zero_bucket_size_in_m: int = 12,
        cpu_offload: bool = False,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        use_ep_inside: bool = True,
        custom_policy: Policy = None,
    ) -> None:
        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

        if enable_sequence_parallelism:
            assert tp_size > 1, "Tensor parallelism must be enabled (tp_size > 1) when sequence parallelism is used"

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        # we change pg mesh to (pp, dp, tp) for better moe performance
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)

        # sync moe in outer dp group, and sync other param in global dp group
        if extra_dp_size > 1:
            ep_size = self.dp_size // extra_dp_size
            if use_ep_inside:
                self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
                self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
                if dist.get_rank() == 0:
                    print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
            else:
                self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
                self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
                if dist.get_rank() == 0:
                    print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
        else:
            self.moe_extra_dp_group = None

        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
            self.schedule = OneForwardOneBackwardSchedule(
                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
            )
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            pipeline_stage_manager=self.stage_manager,
            enable_tensor_parallelism=self.tp_size > 1,
            enable_all_optimization=self.enable_all_optimization,
            enable_fused_normalization=self.enable_fused_normalization,
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            enable_sequence_overlap=enable_sequence_overlap,
        )
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )

        self.ddp_config = dict(
            broadcast_buffers=broadcast_buffers,
            bucket_cap_mb=ddp_bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            check_reduction=check_reduction,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )

        self.zero_config = dict(
            reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            partition_grad=(self.zero_stage == 2),
        )

        self.max_norm = max_norm

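    # Illustrative note (added for clarity; not part of the original source): with 16 GPUs,
    # tp_size=2 and pp_size=2 give dp_size = 16 // (2 * 2) = 4, so the main mesh above is
    # (pp=2, dp=4, tp=2). Passing extra_dp_size=2 further splits the dp dimension into
    # ep_size = 4 // 2 = 2; with use_ep_inside=True the MoE mesh is (pp=2, outer_dp=2, ep=2)
    # and moe_extra_dp_group is taken along the outer-dp axis, so expert-parallel ranks sit
    # on the innermost (adjacent) axis while non-expert parameters are still synchronized
    # across the full dp group. The actual sizes depend on the launch configuration.
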
    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
    ):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): The batch size of each dataloader iteration.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random seed used by the dataloader workers. Defaults to 1024.
            drop_last (bool, optional): Set to True to drop the last incomplete batch if the dataset size
                is not divisible by the batch size. If False, the last batch will simply be smaller. Defaults to False.
            pin_memory (bool, optional): Whether to pin memory in CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details can be found in
                `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(
            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
        )

        # Deterministic dataloader
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )

    def get_checkpoint_io(self) -> MoeCheckpintIO:
        self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

    def configure(
        self,
        model: Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
            model = HybridParallelModule(
                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
            )
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                if self.precision in ["fp16", "bf16"]:
                    optimizer = HybridParallelAMPOptimizer(
                        optimizer,
                        model,
                        use_pipeline=self.enable_pipeline_parallelism,
                        param_info=param_info,
                        precision=self.precision,
                        max_norm=self.max_norm,
                        **self.amp_config,
                    )
                    self.checkpoint_io.link_master_and_working_param(
                        optimizer.working_to_master_map, optimizer.master_to_working_map
                    )
                else:
                    optimizer = HybridParallelNaiveOptimizer(
                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                    )
            else:
                assert self.dp_size > 1, "Please only use ZeRO when the data parallel size is greater than 1."
                assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                optimizer = HybridParallelZeroOptimizer(
                    optimizer,
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    param_info=param_info,
                    dp_process_group=self.dp_group,
                    tp_process_group=self.tp_group,
                    pp_process_group=self.pp_group,
                    moe_extra_dp_process_group=self.moe_extra_dp_group,
                    verbose=True,
                    clip_grad_norm=self.max_norm,
                    **self.zero_config,
                    **self.amp_config,
                )
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)

        return model, optimizer, criterion, dataloader, lr_scheduler
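
# Illustrative usage sketch (added for clarity; not part of the original diff). It assumes a
# user-defined MoE model and dataset -- build_moe_model() and build_dataset() below are
# hypothetical placeholders -- that MoeHybridParallelPlugin is exported from
# colossalai.booster.plugin alongside HybridParallelPlugin, and that the script is launched
# with at least two ranks so that dp_size > 1 (required by ZeRO).
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch(config={})  # initialize the default distributed environment
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1, precision="bf16")
booster = Booster(plugin=plugin)

model, criterion = build_moe_model(), torch.nn.CrossEntropyLoss()  # hypothetical model builder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
dataloader = plugin.prepare_dataloader(build_dataset(), batch_size=8, shuffle=True)  # hypothetical dataset builder
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

for inputs, labels in dataloader:
    loss = criterion(model(inputs), labels)
    booster.backward(loss, optimizer)  # handles loss scaling / ZeRO reduction
    optimizer.step()
    optimizer.zero_grad()
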
@@ -1,7 +1,5 @@
from .config import Config, ConfigException

# from .moe_context import MOE_CONTEXT

__all__ = [
    "Config",
    "ConfigException",
@ -1,132 +0,0 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.legacy.tensor import ProcessGroup
|
||||
|
||||
|
||||
def _check_sanity():
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
|
||||
|
||||
|
||||
class MoeParallelInfo:
|
||||
"""Moe parallelism information, storing parallel sizes and groups."""
|
||||
|
||||
def __init__(self, ep_size: int, dp_size: int):
|
||||
_check_sanity()
|
||||
self.ep_size = ep_size
|
||||
self.dp_size = dp_size
|
||||
self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
|
||||
self.ep_group = self.pg.tp_process_group()
|
||||
self.dp_group = self.pg.dp_process_group()
|
||||
|
||||
|
||||
class MoeContext(metaclass=SingletonMeta):
|
||||
"""MoE parallel context manager. This class manages different
|
||||
parallel groups in MoE context and MoE loss in training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.world_size = 1
|
||||
# Users may want to set maximum expert parallel size smaller than the world size
|
||||
# since very low bandwidth across nodes may constrain the performance of MoE
|
||||
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
|
||||
self.max_ep_size = 1
|
||||
self.min_dp_size = 1
|
||||
self.aux_loss = None
|
||||
self.use_kernel_optim = True
|
||||
|
||||
self.has_setup = False
|
||||
self._parallel_info_dict = dict()
|
||||
|
||||
@property
|
||||
def parallel_info_dict(self):
|
||||
return self._parallel_info_dict
|
||||
|
||||
@property
|
||||
def is_initialized(self):
|
||||
return self.has_setup
|
||||
|
||||
def setup(self, seed: int, use_kernel_optim: bool = True):
|
||||
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
|
||||
_check_sanity()
|
||||
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
|
||||
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
|
||||
assert (
|
||||
self.world_size % self.max_ep_size == 0
|
||||
), "Maximum expert parallel size must be a factor of the number of GPUs"
|
||||
self.min_dp_size = self.world_size // self.max_ep_size
|
||||
|
||||
# Enabling kernel optimization may raise error in some cases
|
||||
# Users can close kernel optimization manually
|
||||
self.use_kernel_optim = use_kernel_optim
|
||||
|
||||
from .random import moe_set_seed
|
||||
|
||||
moe_set_seed(seed)
|
||||
self.has_setup = True
|
||||
|
||||
def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
|
||||
"""Calculate the Data Parallel Group and Expert Parallel Group.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_experts : int
|
||||
            The number of experts
|
||||
|
||||
Returns
|
||||
-------
|
||||
int, MoeParallelInfo
|
||||
number of local experts, the MoeParallelInfo of the current ep_size
|
||||
"""
|
||||
|
||||
        gt_flag = num_experts % self.max_ep_size == 0  # check whether num_experts is greater than max_ep_size
        lt_flag = self.max_ep_size % num_experts == 0  # check whether num_experts is less than max_ep_size

        assert gt_flag or lt_flag, (
            "Automatic expert placement does not support an expert number"
            " that is neither a multiple nor a factor of the expert parallel size."
        )

        # If the number of experts is greater than the maximum expert parallel size (a.k.a. ep_size),
        # there are multiple experts on each GPU and each GPU holds different experts,
        # so the data parallel size is 1.
        # Otherwise, there is only one expert on each GPU and
        # the data parallel size has to be calculated.
        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
|
||||
ep_size = self.max_ep_size // dp_size
|
||||
|
||||
# Calculate the number of experts for each GPU
|
||||
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
|
||||
|
||||
# Don't forget to multiply minimum data parallel size
|
||||
dp_size *= self.min_dp_size
|
||||
if not (ep_size in self.parallel_info_dict):
|
||||
self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
|
||||
|
||||
return num_local_experts, self.parallel_info_dict[ep_size]
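    # Illustrative note (added; not in the original file): with world_size = 8 and
    # max_ep_size = 4 (hence min_dp_size = 2), get_info(num_experts=8) takes the gt branch:
    # dp_size = 1, ep_size = 4, num_local_experts = 8 // 4 = 2, and the final dp_size
    # becomes 1 * 2 = 2, i.e. two experts per GPU replicated across two data-parallel groups.
    # get_info(num_experts=2) takes the lt branch instead: dp_size = 4 // 2 = 2, ep_size = 2,
    # num_local_experts = 1, and the final dp_size is 2 * 2 = 4.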
|
||||
|
||||
def set_kernel_not_use(self):
|
||||
self.use_kernel_optim = False
|
||||
|
||||
def reset_loss(self):
|
||||
self.aux_loss = 0
|
||||
|
||||
def add_loss(self, loss):
|
||||
self.aux_loss += loss
|
||||
|
||||
def get_loss(self):
|
||||
return self.aux_loss
|
||||
|
||||
|
||||
MOE_CONTEXT = MoeContext()
|
|
@ -0,0 +1,185 @@
|
|||
from functools import reduce
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
if HAS_TRITON:
|
||||
PRECISION_MAP = {
|
||||
"fp32": (0, torch.float32),
|
||||
"fp16": (1, torch.float16),
|
||||
"bf16": (2, torch.bfloat16),
|
||||
}
|
||||
|
||||
@triton.jit
|
||||
def _llama_act_combine_forward(
|
||||
X_GATE1,
|
||||
X_GATE2,
|
||||
X_UP,
|
||||
Y,
|
||||
stride, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X_GATE1 += row * stride
|
||||
X_GATE2 += row * stride
|
||||
X_UP += row * stride
|
||||
Y += row * stride
|
||||
|
||||
# do activation and combine, and store in y
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N
|
||||
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
|
||||
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
|
||||
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
|
||||
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
|
||||
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
@triton.jit
|
||||
def _llama_act_combine_backward(
|
||||
X_GATE1,
|
||||
X_GATE2,
|
||||
X_UP,
|
||||
X_GATE1_GRAD,
|
||||
X_GATE2_GRAD,
|
||||
X_UP_GRAD,
|
||||
Y_GRAD,
|
||||
stride, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X_GATE1 += row * stride
|
||||
X_GATE2 += row * stride
|
||||
X_UP += row * stride
|
||||
X_GATE1_GRAD += row * stride
|
||||
X_GATE2_GRAD += row * stride
|
||||
X_UP_GRAD += row * stride
|
||||
Y_GRAD += row * stride
|
||||
|
||||
# do activation and combine, and store in y
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N
|
||||
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
|
||||
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
|
||||
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
|
||||
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
|
||||
|
||||
# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
|
||||
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
|
||||
x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
|
||||
x_up_grad = x_gate2_act * x_gate1
|
||||
x_gate1_grad = x_gate2_act * x_up
|
||||
            # grad(x * sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 - sigmoid(x)]
            #                      = sigmoid(x) * {1 + x * [1 - sigmoid(x)]}
            x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
|
||||
|
||||
# Write output
|
||||
tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
|
||||
tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
|
||||
tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
|
||||
|
||||
class LlamaActCombine(torch.autograd.Function):
|
||||
"""
|
||||
act(x_gate) * x_up
|
||||
|
||||
Args:
|
||||
x_gate (torch.Tensor): (b, l, 2d) x_gate
|
||||
x_up (torch.Tensor): (b, l, d) x_up
|
||||
activation (str): only support swiglu
|
||||
precision (str): fp32, fp16, bf16
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
|
||||
"""
|
||||
act(x_gate) * x_up
|
||||
|
||||
Args:
|
||||
x_gate (torch.Tensor): (b, l, 2d) x gate
|
||||
x_up (torch.Tensor): (b, l, d) x up
|
||||
activation (str): only support swiglu
|
||||
"""
|
||||
assert activation == "swiglu", "Only swiglu is supported"
|
||||
|
||||
# split x gate
|
||||
assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
|
||||
x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
|
||||
x_gate1 = x_gate1.contiguous()
|
||||
x_gate2 = x_gate2.contiguous()
|
||||
if not x_up.is_contiguous():
|
||||
x_up = x_up.contiguous()
|
||||
# assert shape
|
||||
assert x_gate1.shape == x_gate2.shape == x_up.shape
|
||||
|
||||
# add ctx for backward
|
||||
if x_gate.requires_grad:
|
||||
ctx.save_for_backward(x_gate1, x_gate2, x_up)
|
||||
|
||||
# allocate output
|
||||
y = torch.empty_like(x_up)
|
||||
M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
|
||||
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x_gate.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_SIZE:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
# restore setting
|
||||
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
|
||||
# enqueue kernel
|
||||
_llama_act_combine_forward[(M,)](x_gate1,
|
||||
x_gate2,
|
||||
x_up,
|
||||
y,
|
||||
x_up.stride(-2),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
|
||||
# restore from ctx
|
||||
(x_gate1, x_gate2, x_up) = ctx.saved_tensors
|
||||
M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
|
||||
|
||||
# init grad
|
||||
y_grad = grad_outputs[0]
|
||||
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
|
||||
x_gate2), torch.empty_like(x_up)
|
||||
|
||||
# enqueue kernel
|
||||
_llama_act_combine_backward[(M,)](x_gate1,
|
||||
x_gate2,
|
||||
x_up,
|
||||
x_gate1_grad,
|
||||
x_gate2_grad,
|
||||
x_up_grad,
|
||||
y_grad,
|
||||
x_up.stride(-2),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps)
|
||||
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
|
||||
return x_gate_grad, x_up_grad, None, None
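

# Illustrative reference (added for clarity; not part of the original kernel file): the fused
# Triton path above computes y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up, i.e. SwiGLU on
# the two gate halves followed by an element-wise product with x_up. A plain-PyTorch sketch of
# the same math (hypothetical helper name), handy for eyeballing numerical parity, e.g.
# torch.testing.assert_close(LlamaActCombine.apply(x_gate, x_up),
#                            _llama_act_combine_reference(x_gate, x_up))
# on a CUDA device with triton installed:
def _llama_act_combine_reference(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    # split the gate projection into its two halves, mirroring LlamaActCombine.forward
    x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, dim=-1)
    # silu(x) == x * sigmoid(x), so this matches the fused kernel's output
    return x_gate1 * torch.nn.functional.silu(x_gate2) * x_up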
|
|
@@ -1,6 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
@@ -10,6 +9,5 @@ __all__ = [
    "DataParallelGradientHandler",
    "ZeROGradientHandler",
    "PipelineSharedModuleGradientHandler",
    "MoeGradientHandler",
    "SequenceParallelGradientHandler",
]
|
@ -16,7 +16,6 @@ from torch.optim.optimizer import Optimizer
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.context import Config, ConfigException
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
|
||||
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
|
||||
|
@ -36,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
|
|||
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -323,8 +321,6 @@ def initialize(
|
|||
if not use_zero:
|
||||
if is_using_sequence():
|
||||
sync_model_param(model, ParallelMode.SEQUENCE_DP)
|
||||
elif MOE_CONTEXT.is_initialized:
|
||||
sync_moe_model_param(model)
|
||||
elif is_using_ddp():
|
||||
sync_model_param(model, ParallelMode.DATA)
|
||||
else:
|
||||
|
@ -377,14 +373,6 @@ def initialize(
|
|||
"added even though not specified in the configuration",
|
||||
ranks=[0],
|
||||
)
|
||||
elif is_using_ddp() and MOE_CONTEXT.is_initialized:
|
||||
gradient_handler_cfg = [dict(type="MoeGradientHandler")]
|
||||
if verbose:
|
||||
logger.info(
|
||||
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
|
||||
"added even though not specified in the configuration",
|
||||
ranks=[0],
|
||||
)
|
||||
elif is_using_sequence():
|
||||
model = DDP(
|
||||
model,
|
||||
|
|
|
@@ -0,0 +1,17 @@
from .checkpoint import MoeCheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator

__all__ = [
    "MLPExperts",
    "MoeRouter",
    "Top1Router",
    "Top2Router",
    "TopKRouter",
    "NormalNoiseGenerator",
    "UniformNoiseGenerator",
    "SparseMLP",
    "MoeCheckpintIO",
]
@ -0,0 +1,275 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
MOE_KERNEL = None
|
||||
|
||||
|
||||
def load_moe():
|
||||
global MOE_KERNEL
|
||||
from colossalai.kernel.op_builder import MOEBuilder
|
||||
|
||||
MOE_KERNEL = MOEBuilder().load()
|
||||
|
||||
|
||||
class AllGather(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.unsqueeze(0), None
|
||||
|
||||
buffer_shape = (comm_size,) + inputs.shape
|
||||
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.all_gather(buffer_list, inputs, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
return (
|
||||
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class ReduceScatter(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.squeeze(0), None
|
||||
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
output_shape = inputs.shape[1:]
|
||||
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.reduce_scatter(outputs, buffer_list, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
# TODO: support async backward
|
||||
return (
|
||||
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class AllToAll(torch.autograd.Function):
|
||||
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
|
||||
operation in torch.distributed.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
if dist.get_world_size(group) == 1:
|
||||
return inputs, None
|
||||
output = torch.empty_like(inputs)
|
||||
if not overlap:
|
||||
dist.all_to_all_single(output, inputs, group=group)
|
||||
return output, None
|
||||
else:
|
||||
handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
|
||||
return output, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
return (
|
||||
AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class MoeDispatch(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, tokens, mask, dest_idx, ec):
|
||||
s = tokens.size(0)
|
||||
h = tokens.size(1)
|
||||
dtype = tokens.dtype
|
||||
|
||||
if MOE_KERNEL is None:
|
||||
load_moe()
|
||||
if tokens.dtype != torch.float32:
|
||||
tokens = tokens.to(torch.float32)
|
||||
expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
|
||||
if expert_input.dtype != dtype:
|
||||
expert_input = expert_input.to(dtype)
|
||||
ctx.save_for_backward(mask, dest_idx)
|
||||
ctx.s = s
|
||||
ctx.h = h
|
||||
ctx.ec = ec
|
||||
ctx.dtype = dtype
|
||||
|
||||
return expert_input
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad):
|
||||
mask, dest_idx = ctx.saved_tensors
|
||||
if output_grad.dtype != torch.float32:
|
||||
output_grad = output_grad.to(torch.float32)
|
||||
d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
|
||||
if d_tokens.dtype != ctx.dtype:
|
||||
d_tokens = d_tokens.to(ctx.dtype)
|
||||
return d_tokens, None, None, None
|
||||
|
||||
|
||||
class MoeCombine(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
|
||||
assert logits.dtype == torch.float32
|
||||
|
||||
s = logits.size(0)
|
||||
e = logits.size(1)
|
||||
c = ec // e
|
||||
h = expert_tokens.size(-1)
|
||||
dtype = expert_tokens.dtype
|
||||
|
||||
if expert_tokens.dtype != torch.float32:
|
||||
expert_tokens = expert_tokens.to(torch.float32)
|
||||
if MOE_KERNEL is None:
|
||||
load_moe()
|
||||
output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
|
||||
if output.dtype != dtype:
|
||||
output = output.to(dtype)
|
||||
|
||||
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
|
||||
ctx.s = s
|
||||
ctx.e = e
|
||||
ctx.c = c
|
||||
ctx.h = h
|
||||
ctx.dtype = dtype
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, tokens_grad):
|
||||
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
|
||||
if tokens_grad.dtype != torch.float32:
|
||||
tokens_grad = tokens_grad.to(torch.float32)
|
||||
|
||||
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
|
||||
mask, dest_idx)
|
||||
if d_expert.dtype != ctx.dtype:
|
||||
d_expert = d_expert.to(ctx.dtype)
|
||||
|
||||
return d_expert, d_logits, None, None, None
|
||||
|
||||
|
||||
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
|
||||
dim0 = inputs.size(0)
|
||||
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
|
||||
if flag and use_kernel:
|
||||
if MOE_KERNEL is None:
|
||||
load_moe()
|
||||
return MOE_KERNEL.cumsum_sub_one(inputs)
|
||||
else:
|
||||
return torch.cumsum(inputs, dim=0) - 1
|
||||
|
||||
|
||||
class MoeInGradScaler(torch.autograd.Function):
|
||||
"""
|
||||
Scale the gradient back by the number of experts
|
||||
because the batch size increases in the moe stage
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
|
||||
if ctx is not None:
|
||||
ctx.ep_size = ep_size
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
assert len(grad_outputs) == 1
|
||||
grad = grad_outputs[0]
|
||||
if ctx.ep_size != 1:
|
||||
grad = grad * ctx.ep_size
|
||||
return grad, None
|
||||
|
||||
|
||||
class MoeOutGradScaler(torch.autograd.Function):
|
||||
"""
|
||||
Scale the gradient by the number of experts
|
||||
because the batch size increases in the moe stage
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
|
||||
ctx.ep_size = ep_size
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
assert len(grad_outputs) == 1
|
||||
grad = grad_outputs[0]
|
||||
if ctx.ep_size != 1:
|
||||
grad = grad / ctx.ep_size
|
||||
return grad, None
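

# Illustrative note (added; not part of the original file): MoeInGradScaler and
# MoeOutGradScaler are meant to bracket an expert-parallel forward pass, as done in
# colossalai.moe.experts.MLPExperts.forward:
#
#     x = MoeInGradScaler.apply(x, ep_size)   # backward multiplies the incoming grad by ep_size
#     ...                                     # expert computation on the locally held experts
#     x = MoeOutGradScaler.apply(x, ep_size)  # backward divides the outgoing grad by ep_size
#
# Both forward passes are identity ops; only the backward passes rescale gradients to
# compensate for the ep_size-fold change in effective batch size inside the MoE stage.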
|
|
@ -0,0 +1,274 @@
|
|||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
is_safetensors_available,
|
||||
load_shard_state_dict,
|
||||
load_state_dict_into_model,
|
||||
save_config_file,
|
||||
save_state_dict_shards,
|
||||
)
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
|
||||
|
||||
|
||||
class MoeCheckpintIO(HybridParallelCheckpointIO):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
) -> None:
|
||||
assert zero_stage in [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage)
|
||||
self.parallel = MOE_MANAGER.parallel
|
||||
|
||||
    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
        """
        Preprocess the state_dict before loading: slice MoE expert tensors so that each rank
        only keeps the experts it holds locally.
        """
        for name, param in state_dict.items():
            if ".experts." in name:
                if name in dict(model.named_parameters()):
                    model_param = dict(model.named_parameters())[name]
                    if is_moe_tensor(model_param):
                        ep_rank = get_ep_rank(model_param)
                        ep_size = get_ep_size(model_param)
                        expert_num = param.shape[0] // ep_size
                        assert param.shape[0] % ep_size == 0
                        param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
                        state_dict[name] = param
        dist.barrier()
        return state_dict

    def _model_sharder(
        self,
        state_dict: OrderedDict,
        prefix: str = "",
        keep_vars: bool = False,
        size_per_shard: int = 1024,
    ) -> Iterator[Tuple[OrderedDict, int]]:
        # An internal method that breaks the state_dict of the model into shards within a limited size.
        state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if param is None:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
|
||||
state_dict = torch.load(checkpoint)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
"""
|
||||
|
||||
# Check whether the checkpoint uses safetensors.
|
||||
use_safetensors = False
|
||||
if "safetensors" in checkpoint_index_file.name:
|
||||
use_safetensors = True
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
strict = False
|
||||
|
||||
# Load params & buffers to model.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
|
||||
def _load(name: str):
|
||||
if name not in weight_map:
|
||||
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
|
||||
filename = weight_map[name]
|
||||
|
||||
# If this param/buffer has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
return
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(
|
||||
model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True,
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
for name, _ in model.named_parameters():
|
||||
_load(name)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
state_dict = model.state_dict()
|
||||
for name, param in model.named_parameters():
|
||||
if ".experts." in name and is_moe_tensor(param):
|
||||
ep_group = get_ep_group(param)
|
||||
ep_rank = get_ep_rank(param)
|
||||
ep_size = get_ep_size(param)
|
||||
dp_rank = get_dp_rank(param)
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
all_param = [deepcopy(param) for _ in range(ep_size)]
|
||||
# gather param from every ep rank
|
||||
dist.all_gather(all_param, param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.all_gather_object(out, state_dict, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
new_state_dict.update(o)
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def save_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
state_dict = self.pre_save_model(model)
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, checkpoint)
|
||||
dist.barrier()
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
            prefix (str, optional): Prefix of the saved checkpoint files. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict = self.pre_save_model(model)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
|
||||
|
||||
# Devices along the same dp_group share the same copies of model.
|
||||
# So only let the device with dp_rank == 0 save the model.
|
||||
if self.dp_rank != 0:
|
||||
return
|
||||
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
dist.barrier()
|
||||
|
||||
# ========================================================
|
||||
# Abstract methods for optimizer loading/saving implementation
|
||||
# ========================================================
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: Path,
|
||||
gather_dtensor: bool,
|
||||
prefix: str,
|
||||
size_per_shard: int,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,156 @@
|
|||
import math
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
||||
|
||||
class MLPExperts(nn.Module):
|
||||
"""
|
||||
SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts
|
||||
hidden_size (int): The hidden size of MLP
|
||||
intermediate_size (int): The intermediate size of MLP
|
||||
expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
|
||||
activation (optional): The activation function of MLP
|
||||
drop_rate (float, optional): The drop rate of MLP
|
||||
gated (bool, optional): Whether to use gated MLP
|
||||
use_kernel (bool, optional): Whether to use kernel optimization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
expert_parallel: Optional[str] = None,
|
||||
activation: Optional[Callable] = None,
|
||||
drop_rate: Optional[float] = 0,
|
||||
gated: Optional[bool] = False,
|
||||
use_kernel: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert expert_parallel in ["EP", "TP", None]
|
||||
self.expert_parallel = expert_parallel
|
||||
self.num_total_experts = num_experts
|
||||
self.gated = gated
|
||||
self.use_kernel = use_kernel
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
# get expert parallel info
|
||||
if expert_parallel is not None:
|
||||
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
|
||||
num_experts, use_tp=True if expert_parallel == "TP" else False)
|
||||
# get settings for different parallel
|
||||
self.ep_size = get_ep_size(self)
|
||||
if expert_parallel == "TP":
|
||||
intermediate_size = intermediate_size // self.ep_size
|
||||
num_experts = self.num_total_experts
|
||||
else:
|
||||
num_experts = self.num_local_experts
|
||||
else:
|
||||
self.num_local_experts = self.num_total_experts
|
||||
self.ep_size = 1
|
||||
|
||||
if gated:
|
||||
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
|
||||
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
else:
|
||||
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
|
||||
|
||||
self.act_name = activation
|
||||
self.act = get_activation(activation)
|
||||
self.drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if expert_parallel is not None:
|
||||
for param in self.parameters():
|
||||
set_moe_tensor_info(param, self.moe_info)
|
||||
|
||||
# init param
|
||||
self.reset_parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self):
|
||||
# expert param should be different
|
||||
if self.expert_parallel is not None:
|
||||
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
|
||||
else:
|
||||
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
|
||||
with seed_ctx:
|
||||
if self.gated:
|
||||
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
|
||||
else:
|
||||
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
param_slice: Tuple[slice] = (slice(None),),
|
||||
use_sparse: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
forward: hidden_size --> intermediate_size --> hidden_size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
"""
|
||||
x = MoeInGradScaler.apply(x, self.ep_size)
|
||||
|
||||
e = x.size(1)
|
||||
h = x.size(-1)
|
||||
|
||||
x = x.transpose(0, 1)
|
||||
inshape = x.shape
|
||||
x = x.reshape(e, -1, h)
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
seq_len = x.shape[1]
|
||||
with torch.no_grad():
|
||||
mask = x[:, :, 0] != 0.0
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
x_list = []
|
||||
for i in range(e):
|
||||
x_list.append(x[i, :mask[i]])
|
||||
x = x_list
|
||||
|
||||
if self.gated:
|
||||
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
|
||||
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
|
||||
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
|
||||
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
|
||||
else:
|
||||
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
|
||||
else:
|
||||
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
|
||||
x = [self.act(x[i]) for i in range(e)]
|
||||
x = [self.drop(x[i]) for i in range(e)]
|
||||
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
for i in range(e):
|
||||
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
|
||||
|
||||
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
|
||||
x = x.reshape(inshape)
|
||||
x = x.transpose(0, 1).contiguous()
|
||||
x = MoeOutGradScaler.apply(x, self.ep_size)
|
||||
return x
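# --- Illustrative usage sketch (added, not part of the original module) ------------
# A minimal single-process example with expert_parallel=None, so no process groups or
# MOE_MANAGER setup are needed. Toy sizes; it assumes the default activation (None)
# is handled by get_activation.
if __name__ == "__main__":
    experts = MLPExperts(num_experts=4, hidden_size=8, intermediate_size=16)
    x = torch.randn(2, 4, 3, 8)   # (num_groups, num_experts, capacity, hidden_size)
    y = experts(x)                # hidden_size -> intermediate_size -> hidden_size
    print(y.shape)                # torch.Size([2, 4, 3, 8])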
|
|
@ -0,0 +1,361 @@
|
|||
import dataclasses
|
||||
import math
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.load_balance import LoadBalancer
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.routers import MoeRouter, get_router_cls
|
||||
from colossalai.moe.utils import get_noise_generator
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
|
||||
|
||||
|
||||
class SparseMLP(nn.Module):
|
||||
"""A class for users to create MoE modules in their models.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts.
hidden_size (int): The hidden dimension of the model.
intermediate_size (int): The intermediate dimension of each expert MLP.
router_top_k (int, optional): The number of experts each token is dispatched to.
router_capacity_factor_train (float, optional): Capacity factor used in routing during training.
router_capacity_factor_eval (float, optional): Capacity factor used in routing during evaluation.
router_min_capacity (int, optional): The minimum capacity of each expert.
router_noisy_policy (str, optional): The policy of the noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
router_drop_tks (bool, optional): Whether to drop tokens during evaluation.
mlp_activation (str, optional): The activation function of the expert MLP.
mlp_gated (bool, optional): Whether to use a gated expert MLP.
enable_load_balance (bool, optional): Whether to balance expert load by swapping experts across devices.
load_balance_tolerance (float, optional): The tolerance used when searching for a balanced placement.
load_balance_beam_width (int, optional): The beam width of the load-balance search.
load_balance_group_swap_factor (float, optional): Penalty factor that discourages swapping the same device repeatedly.
enable_kernel (bool, optional): Whether to use the optimized kernels.
enable_comm_overlap (bool, optional): Whether to overlap communication with computation.
|
||||
|
||||
.. _Switch Transformer paper:
|
||||
https://arxiv.org/abs/2101.03961
|
||||
.. _ViT-MoE paper:
|
||||
https://arxiv.org/abs/2106.05974
|
||||
.. _Microsoft paper:
|
||||
https://arxiv.org/abs/2201.05596
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
router_top_k: int = 1,
|
||||
router_capacity_factor_train: Optional[float] = 1.25,
|
||||
router_capacity_factor_eval: Optional[float] = 2.0,
|
||||
router_min_capacity: Optional[int] = 4,
|
||||
router_noisy_policy: Optional[str] = None,
|
||||
router_drop_tks: Optional[bool] = True,
|
||||
mlp_activation: Optional[str] = None,
|
||||
mlp_gated: Optional[bool] = False,
|
||||
enable_load_balance: Optional[bool] = False,
|
||||
load_balance_tolerance: Optional[float] = 0.1,
|
||||
load_balance_beam_width: Optional[int] = 8,
|
||||
load_balance_group_swap_factor: Optional[float] = 0.4,
|
||||
enable_kernel: Optional[bool] = False,
|
||||
enable_comm_overlap: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_experts = num_experts
|
||||
self.gated = mlp_gated
|
||||
self.enable_kernel = enable_kernel
|
||||
self.enable_comm_overlap = enable_comm_overlap
|
||||
self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
|
||||
# moe router
|
||||
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
|
||||
router_cls = get_router_cls(router_top_k)
|
||||
self.topk = router_top_k
|
||||
self.router: MoeRouter = router_cls(
|
||||
capacity_factor_train=router_capacity_factor_train,
|
||||
capacity_factor_eval=router_capacity_factor_eval,
|
||||
min_capacity=router_min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=router_drop_tks,
|
||||
)
|
||||
|
||||
# gate
|
||||
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
|
||||
|
||||
# moe experts
|
||||
self.experts = MLPExperts(
|
||||
num_experts=self.num_experts,
|
||||
expert_parallel=self.expert_parallel,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
activation=mlp_activation,
|
||||
gated=mlp_gated,
|
||||
use_kernel=self.enable_kernel,
|
||||
)
|
||||
|
||||
# get parallel settings
|
||||
if self.expert_parallel is not None:
|
||||
self.ep_group = get_ep_group(self.experts)
|
||||
self.ep_size = get_ep_size(self.experts)
|
||||
self.dp_group = get_dp_group(self.experts)
|
||||
else:
|
||||
self.ep_group = None
|
||||
self.dp_group = None
|
||||
self.num_local_experts = self.experts.num_local_experts
|
||||
|
||||
# load balance
|
||||
self.enable_load_balance = enable_load_balance
|
||||
if self.enable_load_balance:
|
||||
self.load_balancer = LoadBalancer(
|
||||
experts=self.experts,
|
||||
gate=self.gate_weight,
|
||||
local_expert_num=self.num_local_experts,
|
||||
expert_num=self.num_experts,
|
||||
ep_group=self.ep_group,
|
||||
dp_group=self.dp_group,
|
||||
tolerance=load_balance_tolerance,
|
||||
beam_width=load_balance_beam_width,
|
||||
group_swap_factor=load_balance_group_swap_factor,
|
||||
)
|
||||
|
||||
# init param
|
||||
self.reset_parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
|
||||
"""
|
||||
# reshape the input tokens
|
||||
tokens = inputs.reshape(-1, self.hidden_size)
|
||||
|
||||
# the data type of the inputs in the gating should be fp32
|
||||
fp32_input = tokens.to(torch.float)
|
||||
fp32_weight = self.gate_weight.to(torch.float)
|
||||
gate_output = F.linear(fp32_input, fp32_weight)
|
||||
|
||||
# update expert load
|
||||
if self.enable_load_balance:
|
||||
with torch.no_grad():
|
||||
# TODO: optimize computation
|
||||
expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
|
||||
# TODO: bincount introduces synchronize, fix it
|
||||
expert_load = torch.bincount(expert_load.view(-1))
|
||||
self.load_balancer.update_load(expert_load)
|
||||
|
||||
# the result from the router
|
||||
route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
|
||||
|
||||
# dispatch_data: (num_experts, capacity, hidden_size)
|
||||
if self.enable_kernel:
|
||||
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
|
||||
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
|
||||
else:
|
||||
sec_mask_f = route_result_list[1].type_as(inputs)
|
||||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
# expert_output: (num_groups, num_experts, capacity, hidden_size)
|
||||
if self.expert_parallel == "EP":
|
||||
expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel == "TP":
|
||||
expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel is None:
|
||||
expert_output = self._local_process(dispatch_data)
|
||||
else:
|
||||
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
|
||||
"Please use Experts build function.")
|
||||
|
||||
if self.enable_kernel:
|
||||
expert_output = expert_output.reshape(-1, self.hidden_size)
|
||||
ans = MoeCombine.apply(expert_output, *route_result_list)
|
||||
else:
|
||||
combine_weights = route_result_list[0].type_as(inputs)
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
ans = torch.matmul(combine_weights, expert_output)
|
||||
|
||||
ans = ans.reshape(inputs.shape)
|
||||
return ans
|
||||
|
||||
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
|
||||
expert_in = expert_in.unsqueeze(0)
|
||||
expert_out = self.experts(expert_in)
|
||||
return expert_out
|
||||
|
||||
def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Expert Parallel
|
||||
|
||||
Args:
|
||||
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: (num_experts, capacity, hidden_size)
|
||||
"""
|
||||
if not overlap or dist.get_world_size(self.ep_group) == 1:
|
||||
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
|
||||
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
|
||||
expert_output = self.experts(expert_input)
|
||||
expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
|
||||
return expert_output
|
||||
|
||||
else:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Capsule:
|
||||
data: torch.Tensor
|
||||
handle: Any = None
|
||||
|
||||
NUM_CHUNK = 4
|
||||
NUM_STAGES = 4
|
||||
|
||||
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
|
||||
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
|
||||
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
|
||||
dispatch_data = dispatch_data.reshape(*input_shape)
|
||||
chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
|
||||
output = torch.empty_like(dispatch_data)
|
||||
|
||||
offset = 0
|
||||
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
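# Descriptive note (added): each iteration below advances a 4-stage software pipeline
# over the chunks -- (1) write back the chunk whose output all-to-all has completed,
# (2) launch the output all-to-all for the chunk just computed, (3) launch the input
# all-to-all for the next chunk, (4) run the experts on the chunk whose input has
# already arrived. Communication and expert computation for different chunks overlap.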
|
||||
|
||||
for i in range(NUM_CHUNK + NUM_STAGES - 1):
|
||||
if expert_out is not None:
|
||||
expert_out.handle.wait()
|
||||
output[:, :, offset:offset + chunk_size, :] = expert_out.data
|
||||
offset += chunk_size
|
||||
expert_out = None
|
||||
|
||||
# all2all last output
|
||||
if _expert_out is not None:
|
||||
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
|
||||
_expert_out = None
|
||||
|
||||
# all2all next input
|
||||
if 0 <= i < NUM_CHUNK:
|
||||
_expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))
|
||||
|
||||
# compute
|
||||
if expert_in is not None:
|
||||
expert_in.handle.wait()
|
||||
_expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
|
||||
expert_in = None
|
||||
|
||||
if _expert_in is not None:
|
||||
expert_in = _expert_in
|
||||
_expert_in = None
|
||||
|
||||
return output
|
||||
|
||||
def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
without overlap:
|
||||
| C |
|
||||
| A | | R |
|
||||
|
||||
with overlap:
|
||||
| C1 || C2 || C3 || C4 |
|
||||
| A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
|
||||
|
||||
where C is computation, A is all gather, R is reduce scatter.
|
||||
|
||||
Args:
|
||||
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: (num_experts, capacity, hidden_size)
|
||||
"""
|
||||
if not overlap or dist.get_world_size(self.ep_group) == 1:
|
||||
expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
|
||||
expert_out = self.experts(expert_in)
|
||||
expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
|
||||
return expert_out
|
||||
else:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Capsule:
|
||||
data: torch.Tensor
|
||||
handle: Any
|
||||
indices: Tuple
|
||||
|
||||
NUM_CHUNK = 4
|
||||
NUM_STAGES = 4
|
||||
|
||||
assert (dispatch_data.shape[0] % NUM_CHUNK == 0
|
||||
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
|
||||
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
|
||||
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
|
||||
output = torch.empty_like(dispatch_data)
|
||||
|
||||
def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
|
||||
return (slice(idx * chunk_size, (idx + 1) * chunk_size),)
|
||||
|
||||
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
|
||||
|
||||
for i in range(NUM_CHUNK + NUM_STAGES - 1):
|
||||
if expert_out is not None:
|
||||
expert_out.handle.wait()
|
||||
output[expert_out.indices] = expert_out.data
|
||||
expert_out = None
|
||||
|
||||
# reduce scatter last output
|
||||
if _expert_out is not None:
|
||||
expert_out = Capsule(
|
||||
*ReduceScatter.apply(_expert_out.data, self.ep_group, True),
|
||||
indices=_expert_out.indices,
|
||||
)
|
||||
_expert_out = None
|
||||
|
||||
# all gather next input
|
||||
if 0 <= i < NUM_CHUNK:
|
||||
_expert_in = Capsule(
|
||||
*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
|
||||
indices=get_chunk_slice(i, chunk_size),
|
||||
)
|
||||
|
||||
# compute
|
||||
if expert_in is not None:
|
||||
expert_in.handle.wait()
|
||||
_expert_out = Capsule(
|
||||
self.experts(expert_in.data, expert_in.indices),
|
||||
handle=None,
|
||||
indices=expert_in.indices,
|
||||
)
|
||||
expert_in = None
|
||||
|
||||
if _expert_in is not None:
|
||||
expert_in = _expert_in
|
||||
_expert_in = None
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def apply_load_balance(model: nn.Module, optim: Any) -> None:
|
||||
"""
|
||||
apply load balance to every experts in the model
|
||||
"""
|
||||
|
||||
def _apply_recursive(module: nn.Module):
|
||||
for _, sub_module in module.named_children():
|
||||
if isinstance(sub_module, SparseMLP):
|
||||
if sub_module.enable_load_balance:
|
||||
sub_module.load_balancer.balance_load(optim)
|
||||
_apply_recursive(sub_module)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
_apply_recursive(model)
|
||||
torch.cuda.empty_cache()
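# --- Illustrative usage sketch (added, not part of the original module) ------------
# A hedged single-rank example: MOE_MANAGER.setup must run before SparseMLP is built;
# it asserts that CUDA is available and queries torch.distributed, so a one-process
# group is created first. Toy sizes; it assumes the defaults (no noisy policy, no
# kernels) run on CPU tensors.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    MOE_MANAGER.setup(seed=42, parallel=None)    # no expert parallelism on one rank
    moe = SparseMLP(num_experts=4, hidden_size=8, intermediate_size=16, router_top_k=2)
    out = moe(torch.randn(2, 5, 8))              # (batch_size, seq_len, hidden_size)
    print(out.shape)                             # torch.Size([2, 5, 8])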
|
|
@ -0,0 +1,442 @@
|
|||
from copy import deepcopy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
class LoadBalancer:
|
||||
def __init__(
|
||||
self,
|
||||
experts: MLPExperts,
|
||||
gate: nn.Parameter,
|
||||
local_expert_num: int,
|
||||
expert_num: int,
|
||||
ep_group: ProcessGroup,
|
||||
dp_group: ProcessGroup,
|
||||
tolerance: Optional[float] = 0.1,
|
||||
beam_width: Optional[int] = 8,
|
||||
group_swap_factor: Optional[float] = 0.4,
|
||||
) -> None:
|
||||
self.experts: MLPExperts = experts
|
||||
self.gate: nn.Parameter = gate
|
||||
self.moe_ep_group: ProcessGroup = ep_group
|
||||
self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
|
||||
self.moe_dp_group: ProcessGroup = dp_group
|
||||
self.tolerance = tolerance
|
||||
self.beam_width = beam_width
|
||||
self.group_swap_factor = group_swap_factor
|
||||
self.local_expert_num = local_expert_num
|
||||
self.expert_num = expert_num
|
||||
self.local_load = None
|
||||
# TODO: use a global process group mesh
|
||||
pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
|
||||
global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
|
||||
self.global_dp_group = global_dp_group.get_group_along_axis(1)
|
||||
self.global_dp_rank = dist.get_rank(self.global_dp_group)
|
||||
self.global_dp_size = dist.get_world_size(self.global_dp_group)
|
||||
|
||||
def _clear_load(self) -> None:
|
||||
self.local_load = None
|
||||
|
||||
def _sync_load(self) -> Tensor:
|
||||
new_load = self.local_load.clone().detach()
|
||||
# all reduce load between ep group
|
||||
dist.all_reduce(new_load, group=self.moe_ep_group)
|
||||
# all reduce load between dp group
|
||||
dist.all_reduce(new_load, group=self.moe_dp_group)
|
||||
return new_load
|
||||
|
||||
@staticmethod
|
||||
def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
|
||||
return abs(sum(data[group]) / len(data[group]) - avg)
|
||||
|
||||
@staticmethod
|
||||
def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
|
||||
data[group_i][index_i], data[group_j][index_j] = (
|
||||
data[group_j][index_j],
|
||||
data[group_i][index_i],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_data(data: List) -> List:
|
||||
max_value = max(max(sublist) for sublist in data)
|
||||
data = [[i / max_value for i in sublist] for sublist in data]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _get_swap_loss(
|
||||
group_swap_factor: float,
|
||||
swap_list: List,
|
||||
group_i: int,
|
||||
index_i: int,
|
||||
group_j: int,
|
||||
index_j: int,
|
||||
) -> float:
|
||||
"""
|
||||
Get swap loss. The swap loss is used to avoid the situation that
|
||||
the same index is swapped twice or the same group is swapped multiple times.
|
||||
"""
|
||||
swap_loss = 0
|
||||
for swap in swap_list:
|
||||
for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
|
||||
# the group has been swapped
|
||||
if group_id in [swap[0], swap[2]]:
|
||||
# the index has been swapped
|
||||
# we want to avoid the situation that the same index is swapped twice
|
||||
if index_id in [swap[1], swap[3]]:
|
||||
swap_loss += 1e5
|
||||
# the index has not been swapped
|
||||
# this is acceptable but should happen as rarely as possible
|
||||
else:
|
||||
swap_loss += group_swap_factor
|
||||
return swap_loss
|
||||
|
||||
@staticmethod
|
||||
def _check_convergence(data: List, avg: float, tolerance: float):
|
||||
"""
|
||||
Check whether the data is converged after swap.
|
||||
"""
|
||||
for sublist in data:
|
||||
if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _beam_search(
|
||||
self,
|
||||
inputs: Tuple[List, float, List],
|
||||
beam_width: int,
|
||||
avg: float,
|
||||
group_swap_factor: float,
|
||||
) -> List:
|
||||
"""
|
||||
Beam search for the best swap combination.
|
||||
Specifically, we swap two elements from two groups and calculate the score.
|
||||
The score is the difference between the origin group sum and the new group sum.
|
||||
The larger the score, the better the swap combination.
|
||||
|
||||
Args:
|
||||
inputs (Tuple): (data, origin_score, swap_list)
|
||||
beam_width (int): beam width for beam search
|
||||
avg (float): average value of the data
|
||||
group_swap_factor (float): group loss for group swap loss
|
||||
|
||||
Returns:
|
||||
List: results list
|
||||
"""
|
||||
data, origin_score, swap_list = inputs
|
||||
results = []
|
||||
group_num = len(data)
|
||||
group_size = len(data[0])
|
||||
origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]
|
||||
|
||||
for group_num_i in range(group_num):
|
||||
for group_size_i in range(group_size):
|
||||
for group_num_j in range(group_num_i + 1, group_num):
|
||||
for group_size_j in range(group_size):
|
||||
new_data = deepcopy(data)
|
||||
# calculate origin group sum
|
||||
origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
|
||||
# swap data
|
||||
self._swap_data(
|
||||
new_data,
|
||||
group_num_i,
|
||||
group_size_i,
|
||||
group_num_j,
|
||||
group_size_j,
|
||||
)
|
||||
# calculate new group sum
|
||||
new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
|
||||
new_data, group_num_j, avg
|
||||
)
|
||||
# calculate score
|
||||
new_score = origin_diff - new_diff
|
||||
if new_score > 0:
|
||||
new_score = origin_score + new_score
|
||||
# get swap loss
|
||||
swap_loss = self._get_swap_loss(
|
||||
group_swap_factor,
|
||||
swap_list,
|
||||
group_num_i,
|
||||
group_size_i,
|
||||
group_num_j,
|
||||
group_size_j,
|
||||
)
|
||||
new_score = new_score - swap_loss
|
||||
# update swap list
|
||||
new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
|
||||
results.append((new_data, new_score, new_swap_list))
|
||||
# sort results
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
# select top k results
|
||||
results = results[:beam_width]
|
||||
return results
|
||||
|
||||
def _load_to_list(self, load: Tensor) -> List:
|
||||
load_len = len(load)
|
||||
assert load_len % self.local_expert_num == 0
|
||||
load_list = []
|
||||
tmp_list = []
|
||||
for i in range(len(load)):
|
||||
tmp_list.append(float(load[i]))
|
||||
if (i + 1) % self.local_expert_num == 0:
|
||||
load_list.append(tmp_list)
|
||||
tmp_list = []
|
||||
return load_list
|
||||
|
||||
def _search_balance(
|
||||
self,
|
||||
data: List,
|
||||
tolerance: Optional[float] = 0.1,
|
||||
beam_width: Optional[int] = 8,
|
||||
group_swap_factor: Optional[float] = 0.4,
|
||||
return_swapped_data: Optional[bool] = False,
|
||||
) -> Tuple[List, List]:
|
||||
"""
|
||||
Search for the best swap combination to balance the data within the specified tolerance.
|
||||
Return the swap list (and optionally the swapped data). The swap list records the swaps:
it is a list of tuples, each describing one swap operation.
|
||||
|
||||
Args:
|
||||
data (List): expert load list.
|
||||
E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
|
||||
This means there are 4 devices and each device has 2 experts.
|
||||
The value is the load of the expert.
|
||||
tolerance (float): tolerance for balance.
|
||||
beam_width (int): beam width for beam search.
|
||||
group_swap_factor (float): group swap factor for group swap loss.
|
||||
The larger it is, the fewer times a group will be swapped.
|
||||
return_swapped_data (bool): whether to return the swapped data.
|
||||
|
||||
Returns:
|
||||
Tuple: (swapped data, swap list) if return_swapped_data is True, otherwise just the swap list.
|
||||
The swap list is a list of tuples. Each tuple is a swap operation.
|
||||
E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
|
||||
the first expert of the first device is swapped with the first expert
|
||||
of the second device.
|
||||
"""
|
||||
norm_data = self._normalize_data(data)
|
||||
avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
|
||||
results = [(norm_data, 0, [])]
|
||||
stop_flag = False
|
||||
|
||||
while not stop_flag:
|
||||
new_results = []
|
||||
best_score = results[0][1]
|
||||
for i in range(len(results)):
|
||||
new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
|
||||
if len(new_results) == 0:
|
||||
stop_flag = True
|
||||
break
|
||||
new_results.sort(key=lambda x: x[1], reverse=True)
|
||||
new_best_score = new_results[0][1]
|
||||
if new_best_score == best_score:
|
||||
stop_flag = True
|
||||
break
|
||||
new_results = new_results[:beam_width]
|
||||
results = new_results
|
||||
for i in results:
|
||||
if self._check_convergence(results[0][0], avg, tolerance):
|
||||
stop_flag = True
|
||||
break
|
||||
|
||||
swap_list = results[0][2]
|
||||
if return_swapped_data:
|
||||
out = deepcopy(data)
|
||||
for swap in swap_list:
|
||||
self._swap_data(out, *swap)
|
||||
return out, swap_list
|
||||
else:
|
||||
return swap_list
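# Descriptive note (added, comments only): with the toy load from the docstring,
#   load = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
#   swaps = self._search_balance(load)
# each returned tuple (g_i, e_i, g_j, e_j) asks device g_i to exchange its e_i-th
# local expert with the e_j-th local expert of device g_j; the concrete tuples depend
# on the beam search and are not fixed here.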
|
||||
|
||||
@staticmethod
|
||||
def _swap_expert_single_tensor(
|
||||
weight: nn.Parameter,
|
||||
expert_idx: int,
|
||||
comm_group: ProcessGroup,
|
||||
send_first: bool,
|
||||
comm_rank: int,
|
||||
):
|
||||
# exchange weight
|
||||
local_weight = weight.data[expert_idx]
|
||||
new_weight = torch.empty_like(local_weight)
|
||||
if send_first:
|
||||
dist.send(local_weight, dst=comm_rank, group=comm_group)
|
||||
dist.recv(new_weight, src=comm_rank, group=comm_group)
|
||||
else:
|
||||
dist.recv(new_weight, src=comm_rank, group=comm_group)
|
||||
dist.send(local_weight, dst=comm_rank, group=comm_group)
|
||||
weight.data[expert_idx] = new_weight
|
||||
|
||||
def _swap_expert_param_and_optim(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
expert_idx: int,
|
||||
comm_group: ProcessGroup,
|
||||
send_first: bool,
|
||||
comm_rank: int,
|
||||
optim: LowLevelZeroOptimizer,
|
||||
):
|
||||
# need to update master and working param if master param exists
|
||||
# else just update working param
|
||||
if weight in optim.optim.state:
|
||||
master_weight_ptr = None
|
||||
working_weight_ptr = weight
|
||||
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
|
||||
else:
|
||||
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
|
||||
working_weight_ptr = weight
|
||||
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
|
||||
|
||||
# exchange weight
|
||||
self._swap_expert_single_tensor(
|
||||
working_weight_ptr,
|
||||
expert_idx,
|
||||
comm_group,
|
||||
send_first,
|
||||
comm_rank,
|
||||
)
|
||||
if master_weight_ptr is not None:
|
||||
# TODO: exchange master weight, skip for now
|
||||
# master weight is shared by dp group
|
||||
tmp = working_weight_ptr.view(-1).split(
|
||||
working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
|
||||
)[dist.get_rank(self.moe_dp_group)]
|
||||
master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
|
||||
# exchange optim
|
||||
self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
|
||||
self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)
|
||||
|
||||
def _gather_global_dp_group(self, data: Tensor) -> Tensor:
|
||||
data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
|
||||
dist.all_gather(data_list, data, group=self.global_dp_group)
|
||||
data_list = torch.cat(data_list, dim=0)
|
||||
return data_list
|
||||
|
||||
def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
|
||||
"""
|
||||
Swap moe param and optim.
|
||||
We use different strategies to swap expert and gate.
|
||||
For expert, we exchange the param and optim of the expert by p2p.
|
||||
For the gate, we all-gather it across the data parallel group, swap the corresponding rows, and copy back the local shard.
|
||||
|
||||
Args:
|
||||
swap_list (List)
|
||||
optim (LowLevelZeroOptimizer)
|
||||
"""
|
||||
# get all experts weights
|
||||
local_rank = dist.get_rank(self.moe_ep_group)
|
||||
if self.experts.gated:
|
||||
weight_list = [self.experts.wi_up, self.experts.wi_gate]
|
||||
else:
|
||||
weight_list = [self.experts.wi]
|
||||
weight_list.append(self.experts.wo)
|
||||
|
||||
# gate optim should be obtained first
|
||||
gate_shape = self.gate.shape
|
||||
# get master weight and optim
|
||||
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
|
||||
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
|
||||
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
|
||||
# gather
|
||||
global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
|
||||
global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
|
||||
global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
|
||||
assert (
|
||||
self.gate.shape
|
||||
== global_master_gate_weight.shape
|
||||
== global_gate_exp_avg.shape
|
||||
== global_gate_exp_avg_sq.shape
|
||||
)
|
||||
|
||||
for swap in swap_list:
|
||||
source_group, source_idx, target_group, target_idx = swap
|
||||
source_rank = self.moe_ep_ranks[source_group]
|
||||
target_rank = self.moe_ep_ranks[target_group]
|
||||
# exchange expert
|
||||
if local_rank in [source_group, target_group]:
|
||||
for weight in weight_list:
|
||||
if local_rank == source_group:
|
||||
self._swap_expert_param_and_optim(
|
||||
weight,
|
||||
source_idx,
|
||||
self.moe_ep_group,
|
||||
True,
|
||||
target_rank,
|
||||
optim,
|
||||
)
|
||||
elif local_rank == target_group:
|
||||
self._swap_expert_param_and_optim(
|
||||
weight,
|
||||
target_idx,
|
||||
self.moe_ep_group,
|
||||
False,
|
||||
source_rank,
|
||||
optim,
|
||||
)
|
||||
# exchange gate
|
||||
source_expert_pos = source_group * self.local_expert_num + source_idx
|
||||
target_expert_pos = target_group * self.local_expert_num + target_idx
|
||||
for gate in [
|
||||
self.gate,
|
||||
global_master_gate_weight,
|
||||
global_gate_exp_avg,
|
||||
global_gate_exp_avg_sq,
|
||||
]:
|
||||
origin_source = gate.data[source_expert_pos].clone().detach()
|
||||
origin_target = gate.data[target_expert_pos].clone().detach()
|
||||
gate.data[source_expert_pos], gate.data[target_expert_pos] = (
|
||||
origin_target,
|
||||
origin_source,
|
||||
)
|
||||
|
||||
# update gate
|
||||
global_master_gate_weight = global_master_gate_weight.view(-1).split(
|
||||
global_master_gate_weight.numel() // self.global_dp_size
|
||||
)[self.global_dp_rank]
|
||||
master_gate_weight.data.copy_(global_master_gate_weight)
|
||||
global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
|
||||
self.global_dp_rank
|
||||
]
|
||||
gate_exp_avg.data.copy_(global_gate_exp_avg)
|
||||
global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
|
||||
global_gate_exp_avg_sq.numel() // self.global_dp_size
|
||||
)[self.global_dp_rank]
|
||||
gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_load(self, load: Tensor) -> None:
|
||||
if len(load) != self.expert_num:
|
||||
padding_size = self.expert_num - len(load)
|
||||
padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
|
||||
load = torch.cat((load, padding), dim=0)
|
||||
if self.local_load is None:
|
||||
self.local_load = load
|
||||
else:
|
||||
self.local_load += load
|
||||
|
||||
@torch.no_grad()
|
||||
def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
|
||||
# prepare load
|
||||
load = self._sync_load()
|
||||
load = self._load_to_list(load)
|
||||
# search balance
|
||||
swap_list = self._search_balance(load)
|
||||
if dist.get_rank() == 0:
|
||||
if len(swap_list) > 0:
|
||||
print(f"[Load Balance] Applying expert swap...")
|
||||
else:
|
||||
print(f"[Load Balance] Invalid swap, skip...")
|
||||
# swap expert and gate
|
||||
self._swap_moe_param(swap_list, optim)
|
||||
# clear load
|
||||
self._clear_load()
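# Descriptive note (added): in the intended flow, the owning SparseMLP calls
# update_load() with the per-step expert token counts during its forward pass, and
# apply_load_balance(model, optim) in the training loop triggers balance_load(),
# which syncs the accumulated load, searches for swaps, and exchanges expert weights
# plus optimizer states across ranks.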
|
|
@ -1,11 +1,9 @@
|
|||
import torch.nn as nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.legacy.registry import LOSSES
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MoeCrossEntropyLoss(_Loss):
|
||||
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
|
||||
|
||||
|
@ -45,11 +43,10 @@ class MoeCrossEntropyLoss(_Loss):
|
|||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
main_loss = self.loss(*args)
|
||||
aux_loss = MOE_CONTEXT.get_loss()
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MoeLoss(_Loss):
|
||||
"""A wrapper class for any loss module to add with auxiliary loss.
|
||||
|
||||
|
@ -77,5 +74,5 @@ class MoeLoss(_Loss):
|
|||
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
|
||||
"""
|
||||
main_loss = self.loss_fn(*args, **kwargs)
|
||||
aux_loss = MOE_CONTEXT.get_loss()
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
|
@ -0,0 +1,162 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.moe_tensor.api import get_moe_info
|
||||
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
||||
|
||||
|
||||
class MoeManager(metaclass=SingletonMeta):
|
||||
"""MoE manager. This class manages different
|
||||
parallel groups in MoE context and MoE loss in training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.parallel = None
|
||||
self.seed = None
|
||||
self.mode = None
|
||||
self.use_ep_inside = None
|
||||
self.world_size = None
|
||||
self._parallel_info_dict = dict()
|
||||
|
||||
# router
|
||||
self.router_aux_loss = []
|
||||
self.router_z_loss = []
|
||||
|
||||
# fixed mode
|
||||
self.pp_size = None
|
||||
self.dp_size = None
|
||||
self.ep_size = None
|
||||
|
||||
# dynamic mode
|
||||
# Users may want to set maximum expert parallel size smaller than the world size
|
||||
# since very low bandwidth across nodes may constrain the performance of MoE
|
||||
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
|
||||
self.max_ep_size = None
|
||||
|
||||
self.has_setup = False
|
||||
|
||||
@property
|
||||
def parallel_info_dict(self):
|
||||
return self._parallel_info_dict
|
||||
|
||||
@property
|
||||
def is_initialized(self):
|
||||
return self.has_setup
|
||||
|
||||
def setup(
|
||||
self,
|
||||
seed: int,
|
||||
parallel: str = None,
|
||||
mode: str = "dynamic",
|
||||
max_ep_size: int = 8,
|
||||
fixed_dp_size: int = 0,
|
||||
fixed_ep_size: int = 0,
|
||||
fixed_pp_size: int = 0,
|
||||
use_ep_inside: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Setup MoE distributed context.
|
||||
|
||||
Args:
|
||||
seed (int): Random seed.
parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
In fixed mode, the ep size, dp size and pp size are fixed.
In dynamic mode, the ep size and dp size are derived from the number of experts.
max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
|
||||
"""
|
||||
assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
|
||||
assert torch.cuda.is_available(), "MoE requires CUDA to be available"
|
||||
|
||||
self.seed = seed + dist.get_rank()
|
||||
self.parallel = parallel
|
||||
self.use_ep_inside = use_ep_inside
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
# init by mode
|
||||
self.mode = mode
|
||||
assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
|
||||
if self.mode == "dynamic":
|
||||
self.max_ep_size = min(max_ep_size, self.world_size)
|
||||
else:
|
||||
assert (fixed_dp_size > 0 and fixed_ep_size > 0
|
||||
and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
|
||||
assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
|
||||
and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
|
||||
self.ep_size = fixed_ep_size
|
||||
self.dp_size = fixed_dp_size
|
||||
self.pp_size = fixed_pp_size
|
||||
|
||||
self.has_setup = True
|
||||
|
||||
def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
|
||||
"""Calculate the Data Parallel Group and Expert Parallel Group.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_experts : int
|
||||
The number of experts
|
||||
|
||||
Returns
|
||||
-------
|
||||
int, MoeParallelInfo
|
||||
number of local experts, the MoeParallelInfo of the current ep_size
|
||||
"""
|
||||
|
||||
if self.mode == "dynamic":
|
||||
gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
assert gt_flag or lt_flag, ("Automatic expert placement requires the number of experts"
" to be a multiple of the ep size, or vice versa.")
|
||||
dp_size = 1 if gt_flag else self.world_size // num_experts
|
||||
ep_size = min(self.world_size // dp_size, self.max_ep_size)
|
||||
dp_size = self.world_size // ep_size
|
||||
pp_size = 1
|
||||
else:
|
||||
dp_size = self.dp_size
|
||||
ep_size = self.ep_size
|
||||
pp_size = self.pp_size
|
||||
|
||||
# Calculate the number of experts for each GPU
|
||||
if use_tp:
|
||||
num_local_experts = num_experts
|
||||
else:
|
||||
if self.mode == "dynamic":
|
||||
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
|
||||
else:
|
||||
num_local_experts = num_experts // ep_size
|
||||
|
||||
if ep_size not in self.parallel_info_dict:
|
||||
self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
|
||||
if dist.get_rank() == 0:
|
||||
if self.use_ep_inside:
|
||||
print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
|
||||
else:
|
||||
print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")
|
||||
|
||||
return num_local_experts, self.parallel_info_dict[ep_size]
|
||||
|
||||
def reset_loss(self):
|
||||
self.router_aux_loss, self.router_z_loss = [], []
|
||||
|
||||
def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
|
||||
self.router_aux_loss.append(aux_loss)
|
||||
self.router_z_loss.append(z_loss)
|
||||
|
||||
def get_loss(self):
|
||||
cur_loss = self.router_aux_loss, self.router_z_loss
|
||||
return cur_loss
|
||||
|
||||
def get_parallel(self):
|
||||
return self.parallel
|
||||
|
||||
|
||||
MOE_MANAGER = MoeManager()
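# --- Illustrative configuration sketch (added, comments only) ----------------------
# Values are examples, not defaults; setup() can only be called once per process and
# must run before MoE layers that rely on it are constructed.
#   dynamic mode -- ep/dp sizes are derived per layer from the expert count:
#       MOE_MANAGER.setup(seed=42, parallel="EP", mode="dynamic", max_ep_size=8)
#   fixed mode -- all three sizes must be given explicitly:
#       MOE_MANAGER.setup(seed=42, parallel="EP", mode="fixed",
#                         fixed_dp_size=2, fixed_ep_size=4, fixed_pp_size=1)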
|
|
@ -0,0 +1,419 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.moe._operation import moe_cumsum
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum number of the capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
use_kernel: bool = False):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
self.use_kernel = use_kernel
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return int(capacity)
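# Worked example (added): with k_value=1, capacity_factor_train=1.25 and logits of
# shape (4096, 8), capacity = floor(1 * 1.25 * 4096 / 8) = 640; it is already even and
# above min_capacity, so 640 slots per expert are kept. An odd intermediate value,
# e.g. floor(1.25 * 100 / 8) = 15, is rounded up to the next even number, 16.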
|
||||
|
||||
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
|
||||
"""Computes auxiliary load balancing loss as in Switch Transformer.
|
||||
|
||||
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
|
||||
implements the loss function presented in equations (4) - (6). It aims to
|
||||
penalize those cases where the routing between experts is unbalanced.
|
||||
|
||||
Args:
|
||||
router_probs: Probability assigned to each expert per token. Shape:
|
||||
<float32>[num_groups, tokens_per_group, num_experts].
|
||||
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
|
||||
indices identifying the top num_selected_experts for a given token.
|
||||
"""
|
||||
assert self._aux_loss is None
|
||||
if router_probs.dim() == expert_indices.dim() == 2:
|
||||
router_probs = router_probs.unsqueeze(0)
|
||||
expert_indices = expert_indices.unsqueeze(0)
|
||||
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs and expert_indices must be 3D tensors"
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_indices, num_experts)
|
||||
# For a given token, determine if it was routed to a given expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts]
|
||||
expert_mask = expert_mask.max(dim=-2)[0]
|
||||
|
||||
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
|
||||
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
|
||||
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
|
||||
self._aux_loss = aux_loss
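# Scale check (added): with perfectly uniform top-1 routing both factors equal
# 1 / num_experts per expert, so aux_loss = E**2 * mean(1/E * 1/E) = 1.0; values
# above 1 indicate the load imbalance that this loss penalizes.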
|
||||
|
||||
def set_z_loss(self, router_logits: torch.Tensor):
|
||||
"""Compute router z-loss.
|
||||
|
||||
The router z-loss was introduced in Designing Effective Sparse Expert Models
|
||||
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
|
||||
small in an effort to improve stability.
|
||||
|
||||
Args:
|
||||
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
|
||||
"""
|
||||
assert self._z_loss is None
|
||||
if router_logits.dim() == 2:
|
||||
router_logits = router_logits.unsqueeze(0)
|
||||
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
|
||||
num_groups, tokens_per_group, _ = router_logits.shape
|
||||
log_z = torch.logsumexp(router_logits, dim=-1)
|
||||
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
|
||||
self._z_loss = z_loss
|
||||
|
||||
def pop_router_loss(self) -> torch.Tensor:
|
||||
assert self._aux_loss is not None
|
||||
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
can be found in Google's Switch Transformer paper.
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert.
|
||||
select_policy (str, optional): The policy about tokens selection.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device()),
).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
num_experts = probs.size(-1)
|
||||
capacity = self.get_capacity(inputs.shape)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
# calculate router loss
|
||||
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * probs.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask
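# Usage sketch (added, comments only): gate logits must be fp32 of shape
# (batch_size * seq_len, num_experts).
#   router = Top1Router(capacity_factor_train=1.25)
#   combine_weights, sec_mask = router(gate_logits)
# Both outputs have shape (batch_size * seq_len, num_experts, capacity); with
# use_kernel=True the router instead returns (probs, mask, dest_idx, ec) for the
# fused dispatch/combine kernels.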
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
can be found in the ViT-MoE paper.
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
num_experts = probs.size(-1)
|
||||
capacity = self.get_capacity(inputs.shape)
|
||||
|
||||
top1_idx = torch.argmax(probs, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = (mask1 + mask2) # loss: [s, e]
|
||||
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
|
||||
|
||||
# calculate loss
|
||||
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
|
||||
self.set_aux_loss(probs, expert_indices, num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
# >>> original code
|
||||
# weight1 = mask1 * probs.type_as(inputs)
|
||||
# weight2 = mask2 * probs.type_as(inputs)
|
||||
# rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
# rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
# cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
# cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
# cb_weight = cb_weight1 + cb_weight2
|
||||
# sec_mask = cb_weight.bool()
|
||||
|
||||
weight1 = mask1 * probs.type_as(inputs)
|
||||
weight2 = mask2 * probs.type_as(inputs)
|
||||
|
||||
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
|
||||
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
|
||||
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
|
||||
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
|
||||
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
|
||||
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
|
||||
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
|
||||
|
||||
return cb_weight, sec_mask
|
||||
|
||||
|
||||
class TopKRouter(MoeRouter):
|
||||
"""Masked matmul router using tokens choose top-k experts assignment.
|
||||
|
||||
NOTE: this is modified from flaxformer.
|
||||
This router uses the same mechanism as in Switch Transformer
|
||||
(https://arxiv.org/abs/2101.03961) and V-MoE
|
||||
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
|
||||
sorted by router_probs and then routed to their choice of expert until the
|
||||
expert's expert_capacity is reached. There is no guarantee that each token is
|
||||
processed by an expert, or that each expert receives at least one token.
|
||||
|
||||
Attributes:
|
||||
num_selected_experts: Maximum number of experts to which each token is
|
||||
routed. Tokens may be routed to fewer experts if particular experts are
|
||||
oversubscribed / reach capacity.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_selected_experts: int,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
|
||||
drop_tks)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
router_probs: torch.Tensor,
|
||||
expert_capacity: int,
|
||||
) -> Tuple:
|
||||
"""Computes masks for the top-k experts per token.
|
||||
|
||||
Args:
|
||||
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
|
||||
probabilities used to determine the routing of tokens to the experts.
|
||||
|
||||
Returns:
|
||||
Dispatch and combine arrays for routing with masked matmuls.
|
||||
"""
|
||||
# TODO: add parallel group
|
||||
num_groups, _, num_experts = router_probs.shape
|
||||
|
||||
# Top-k router probability and corresponding expert indices for each token.
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts].
|
||||
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
|
||||
|
||||
self.set_aux_loss(router_probs, expert_index, num_experts)
|
||||
self.pop_router_loss()
|
||||
|
||||
# Make num_selected_experts the leading axis to ensure that top-1 choices
|
||||
# have priority over top-2 choices, which have priority over top-3 choices,
|
||||
# etc.
|
||||
expert_index = torch.transpose(expert_index, 1, 2)
|
||||
# Shape: [num_groups, num_selected_experts * tokens_per_group]
|
||||
expert_index = expert_index.reshape(num_groups, -1)
|
||||
|
||||
# Create mask out of indices.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
|
||||
|
||||
# Experts have a fixed capacity that we cannot exceed. A token's priority
|
||||
# within the expert's buffer is given by the masked, cumulative capacity of
|
||||
# its target expert.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
|
||||
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
|
||||
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
token_priority = torch.transpose(token_priority, 1, 2)
|
||||
# For each token, across all selected experts, select the only non-negative
|
||||
# (unmasked) priority. Now, for group G routing to expert E, token T has
|
||||
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
|
||||
# is its targeted expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts].
|
||||
token_priority = torch.max(token_priority, dim=2)[0]
|
||||
|
||||
# Token T can only be routed to expert E if its priority is positive and
|
||||
# less than the expert capacity. One-hot matrix will ignore indices outside
|
||||
# the range [0, expert_capacity).
|
||||
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
|
||||
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
|
||||
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
|
||||
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
|
||||
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
|
||||
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
|
||||
|
||||
# The combine array will be used for combining expert outputs, scaled by the
|
||||
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
|
||||
# expert_capacity].
|
||||
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
|
||||
|
||||
return combine_array, dispatch_mask
|
||||
|
||||
|
||||
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
|
||||
if not grouped:
|
||||
if top_k == 1:
|
||||
return Top1Router
|
||||
elif top_k == 2:
|
||||
return Top2Router
|
||||
else:
|
||||
raise NotImplementedError("top_k > 2 is not supported yet")
|
||||
else:
|
||||
return TopKRouter
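# Illustrative usage sketch (the helper name `_example_topk_routing` is not part
# of the original API and the shapes below are arbitrary). It assumes this
# module's imports (torch, F) and the MoeRouter base class providing
# set_aux_loss / pop_router_loss, as used by TopKRouter.forward above.
def _example_topk_routing():
    router_cls = get_router_cls(top_k=2, grouped=True)  # -> TopKRouter
    router = router_cls(num_selected_experts=2)
    # 1 group of 8 tokens over 4 experts; probabilities sum to 1 over experts.
    router_probs = torch.softmax(torch.randn(1, 8, 4), dim=-1)
    combine_array, dispatch_mask = router(router_probs, expert_capacity=4)
    # combine_array: [1, 8, 4, 4] routing weights; dispatch_mask: same shape, bool.
    return combine_array, dispatch_mask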
|
|
@ -0,0 +1,177 @@
|
|||
import contextlib
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
||||
def half(self, memory_format=None):
|
||||
return self.data.clone()
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
"""Generates a random noisy mask for logits tensor.
|
||||
|
||||
All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
|
||||
`E = the number of experts`.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int):
|
||||
self.normal = torch.distributions.normal.Normal(
|
||||
loc=torch.tensor(0.0, device=get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.normal(inputs.shape)
|
||||
return inputs + noisy
|
||||
|
||||
|
||||
class UniformNoiseGenerator:
|
||||
"""Generates a random noisy mask for logits tensor.
|
||||
copied from mesh tensorflow:
|
||||
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
|
||||
Makes models more resilient to rounding errors introduced by bfloat16.
|
||||
This seems particularly important for logits.
|
||||
|
||||
Args:
|
||||
eps (float, optional): Epsilon in generator, defaults 1e-2.
|
||||
"""
|
||||
|
||||
def __init__(self, eps: float = 1e-2):
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(1.0 - eps, device=get_current_device()),
|
||||
high=torch.tensor(1.0 + eps, device=get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.uniform(inputs.shape)
|
||||
return inputs * noisy
|
||||
|
||||
|
||||
def autocast_softmax(logit: torch.Tensor, dim: int):
    # Always compute the softmax in float32 for numerical stability.
    return F.softmax(logit, dim=dim, dtype=torch.float32)
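# Illustrative sketch (helper name is not part of the original API): routing
# logits in fp16/bf16 are promoted to float32 inside the softmax.
def _example_fp32_softmax():
    logits = torch.randn(8, 4, dtype=torch.float16)
    probs = autocast_softmax(logits, dim=-1)
    return probs.dtype  # torch.float32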
|
||||
|
||||
|
||||
def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
|
||||
if noise_type is None:
|
||||
return None
|
||||
elif noise_type == "Jitter":
|
||||
noisy_func = UniformNoiseGenerator()
|
||||
elif noise_type == "Gaussian":
|
||||
noisy_func = NormalNoiseGenerator(num_experts)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported input noisy policy")
|
||||
return noisy_func
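# Illustrative sketch (helper name is not part of the original API; assumes a
# device is available through get_current_device): apply the two supported
# noise policies to a toy logits tensor before routing.
def _example_noisy_logits():
    logits = torch.randn(8, 4, device=get_current_device())  # [tokens, experts]
    jitter = get_noise_generator("Jitter", num_experts=4)      # multiplicative noise
    gaussian = get_noise_generator("Gaussian", num_experts=4)  # additive noise
    return jitter(logits), gaussian(logits)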
|
||||
|
||||
|
||||
def get_activation(act: str) -> Callable:
|
||||
if act is None or act == "relu":
|
||||
return torch.nn.ReLU()
|
||||
elif act == "gelu":
|
||||
return torch.nn.GELU()
|
||||
elif act == "swiglu":
|
||||
return SwiGLU
|
||||
else:
|
||||
raise NotImplementedError("Unsupported activation function")
|
||||
|
||||
|
||||
def SwiGLU(x):
    """SwiGLU activation: split the last dimension in half and gate one half
    with SiLU of the other.

    Args:
        x: input tensor whose last dimension has even size.
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2))
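# Illustrative sketch: an input of width 2d yields an output of width d, with
# one half of the features gated by SiLU of the other half.
def _example_swiglu():
    x = torch.randn(2, 8)  # last dimension must be even
    return SwiGLU(x)       # shape [2, 4]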
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def skip_init():
    """
    Temporarily disable the torch.nn.init random initializers so that modules
    created inside this context skip random parameter initialization.
    """
|
||||
|
||||
def _skip_init(*args, **kwargs):
|
||||
pass
|
||||
|
||||
init_func = {
|
||||
"constant_": torch.nn.init.constant_,
|
||||
"uniform_": torch.nn.init.uniform_,
|
||||
"normal_": torch.nn.init.normal_,
|
||||
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
|
||||
"kaiming_normal_": torch.nn.init.kaiming_normal_,
|
||||
"xavier_normal_": torch.nn.init.xavier_normal_,
|
||||
"xavier_uniform_": torch.nn.init.xavier_uniform_,
|
||||
"trunc_normal_": torch.nn.init.trunc_normal_,
|
||||
}
|
||||
|
||||
for method_name, original_init in init_func.items():
|
||||
setattr(torch.nn.init, method_name, _skip_init)
|
||||
|
||||
yield
|
||||
|
||||
for method_name, original_init in init_func.items():
|
||||
setattr(torch.nn.init, method_name, original_init)
|
||||
|
||||
return
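# Illustrative sketch (helper name is not part of the original API): build a
# module without paying for random initialization, e.g. when a checkpoint will
# overwrite the weights right afterwards.
def _example_skip_init():
    with skip_init():
        layer = torch.nn.Linear(1024, 1024)  # init_* calls are no-ops inside the context
    return layer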
|
||||
|
||||
|
||||
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
|
||||
"""Returns a parameter dictionary, the key of which is the expert parallel
|
||||
size of every parameter. Since the parameters in data parallelism is replicated
|
||||
in each GPU, we set their ep_size to 1.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
|
||||
"""
|
||||
epsize_param_dict = dict()
|
||||
for param in model.parameters():
|
||||
if not is_moe_tensor(param):
|
||||
ep_size = 1 # set ep_size to 1 for dp parameters
|
||||
else:
|
||||
ep_size = get_ep_size(param)
|
||||
if ep_size not in epsize_param_dict:
|
||||
epsize_param_dict[ep_size] = []
|
||||
epsize_param_dict[ep_size].append(param)
|
||||
|
||||
return epsize_param_dict
|
||||
|
||||
|
||||
def sync_moe_model_param(model: nn.Module):
|
||||
"""Make sure model parameters are consistent in MoE parallel context.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
"""
|
||||
param_dict = get_moe_epsize_param_dict(model)
|
||||
|
||||
# synchronize the parameters whose dp_group is the whole world
|
||||
if 1 in param_dict:
|
||||
for param in param_dict[1]:
|
||||
dist.broadcast(param, src=0)
|
||||
|
||||
for ep_size in param_dict:
|
||||
# When ep_size = world_size, communication is not needed
|
||||
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
|
||||
for param in param_dict[ep_size]:
|
||||
src_rank = get_dp_group_ranks(param)[0]
|
||||
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
|
||||
|
||||
|
||||
def set_moe_args(config: Any, args: dict):
|
||||
for k, v in args.items():
|
||||
setattr(config, k, v)
|
|
@ -1,2 +1 @@
|
|||
# from .moe import *
|
||||
from .utils import *
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
from .checkpoint import load_moe_model, save_moe_model
|
||||
from .experts import Experts, FFNExperts, TPExperts
|
||||
from .layers import MoeLayer, MoeModule
|
||||
from .routers import MoeRouter, Top1Router, Top2Router
|
||||
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
|
||||
|
||||
__all__ = [
|
||||
"Experts",
|
||||
"FFNExperts",
|
||||
"TPExperts",
|
||||
"Top1Router",
|
||||
"Top2Router",
|
||||
"MoeLayer",
|
||||
"NormalNoiseGenerator",
|
||||
"UniformNoiseGenerator",
|
||||
"build_ffn_experts",
|
||||
"MoeModule",
|
||||
"MoeRouter",
|
||||
"save_moe_model",
|
||||
"load_moe_model",
|
||||
]
|
|
@ -1,171 +0,0 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
COL_MOE_KERNEL_FLAG = False
|
||||
|
||||
try:
|
||||
from colossalai._C import moe
|
||||
except:
|
||||
moe = None
|
||||
|
||||
|
||||
def build_moe_if_not_prebuilt():
|
||||
# load moe kernel during runtime if not pre-built
|
||||
global moe
|
||||
if moe is None:
|
||||
from colossalai.kernel.op_builder import MOEBuilder
|
||||
|
||||
moe = MOEBuilder().load()
|
||||
|
||||
|
||||
class AllGather(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
global moe
|
||||
|
||||
if moe is None:
|
||||
from colossalai.kernel.op_builder import MOEBuilder
|
||||
|
||||
moe = MOEBuilder().load()
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.unsqueeze(0)
|
||||
|
||||
buffer_shape = (comm_size,) + inputs.shape
|
||||
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
||||
dist.all_gather(buffer_list, inputs, group=group)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
|
||||
|
||||
|
||||
class ReduceScatter(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.squeeze(0)
|
||||
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
output_shape = inputs.shape[1:]
|
||||
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
||||
dist.reduce_scatter(outputs, buffer_list, group=group)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
|
||||
|
||||
|
||||
class AllToAll(torch.autograd.Function):
|
||||
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
|
||||
operation in torch.distributed.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
if dist.get_world_size(group) == 1:
|
||||
return inputs
|
||||
output = torch.empty_like(inputs)
|
||||
dist.all_to_all_single(output, inputs, group=group)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
|
||||
|
||||
|
||||
class MoeDispatch(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tokens, mask, dest_idx, ec):
|
||||
s = tokens.size(0)
|
||||
h = tokens.size(1)
|
||||
|
||||
# load moe kernel during runtime if not pre-built
|
||||
build_moe_if_not_prebuilt()
|
||||
|
||||
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
|
||||
|
||||
ctx.save_for_backward(mask, dest_idx)
|
||||
ctx.s = s
|
||||
ctx.h = h
|
||||
ctx.ec = ec
|
||||
|
||||
return expert_input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
mask, dest_idx = ctx.saved_tensors
|
||||
d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
|
||||
return d_tokens, None, None, None
|
||||
|
||||
|
||||
class MoeCombine(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
|
||||
assert logits.dtype == torch.float32
|
||||
|
||||
s = logits.size(0)
|
||||
e = logits.size(1)
|
||||
c = ec // e
|
||||
h = expert_tokens.size(-1)
|
||||
|
||||
# load moe kernel during runtime if not pre-built
|
||||
build_moe_if_not_prebuilt()
|
||||
|
||||
fp16_flag = expert_tokens.dtype == torch.float16
|
||||
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
|
||||
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
|
||||
output = ctokens.to(torch.float16) if fp16_flag else ctokens
|
||||
|
||||
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
|
||||
ctx.s = s
|
||||
ctx.e = e
|
||||
ctx.c = c
|
||||
ctx.h = h
|
||||
ctx.fp16_flag = fp16_flag
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, tokens_grad):
|
||||
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
|
||||
|
||||
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
|
||||
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
|
||||
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
|
||||
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
|
||||
|
||||
return d_expert, d_logits, None, None, None
|
||||
|
||||
|
||||
def moe_cumsum(inputs: Tensor):
|
||||
dim0 = inputs.size(0)
|
||||
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
|
||||
if flag and COL_MOE_KERNEL_FLAG:
|
||||
# load moe kernel during runtime if not pre-built
|
||||
build_moe_if_not_prebuilt()
|
||||
return moe.cumsum_sub_one(inputs)
|
||||
else:
|
||||
return torch.cumsum(inputs, dim=0) - 1
|
|
@ -1,40 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from .experts import MoeExperts
|
||||
|
||||
|
||||
def save_moe_model(model: nn.Module, save_path: str):
|
||||
state_dict = model.state_dict()
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, save_path)
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def load_moe_model(model: nn.Module, load_path: str):
|
||||
state_dict = torch.load(load_path)
|
||||
|
||||
for prefix, module in model.named_modules():
|
||||
if prefix.endswith(".moe_layer.experts"):
|
||||
# this module should be an Experts instance
|
||||
assert isinstance(module, MoeExperts)
|
||||
|
||||
ep_rank = dist.get_rank(module.dist_info.ep_group)
|
||||
num_local = module.num_local_experts
|
||||
for i in range(num_local):
|
||||
expert_id = ep_rank * num_local + i
|
||||
for name, _ in module.experts[i].named_parameters():
|
||||
cur_key = f"{prefix}.experts.{i}.{name}"
|
||||
param_key = f"{prefix}.experts.{expert_id}.{name}"
|
||||
load_param = state_dict[param_key]
|
||||
state_dict[cur_key] = load_param
|
||||
|
||||
for name, _ in module.experts[0].named_parameters():
|
||||
pop_pre = f"{prefix}.experts."
|
||||
pop_suf = f".{name}"
|
||||
for i in range(num_local, module.num_total_experts):
|
||||
pop_key = f"{pop_pre}{i}{pop_suf}"
|
||||
state_dict.pop(pop_key)
|
||||
|
||||
model.load_state_dict(state_dict)
|
|
@ -1,201 +0,0 @@
|
|||
import math
|
||||
from copy import deepcopy
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MoeExperts(nn.Module):
|
||||
"""Basic class for experts in MoE. It stores what kind of communication experts use
|
||||
to exchange tokens, how many experts in a single GPU and parallel information such as
|
||||
expert parallel size, data parallel size and their distributed communication groups.
|
||||
"""
|
||||
|
||||
def __init__(self, comm_name: str, num_experts: int):
|
||||
super().__init__()
|
||||
assert comm_name in {
|
||||
"all_to_all",
|
||||
"all_gather",
|
||||
}, "This kind of communication has not been implemented yet.\n Please use Experts build function."
|
||||
self.comm_name = comm_name
|
||||
self.num_total_experts = num_experts
|
||||
# Get the configuration of experts' deployment and parallel information from moe context
|
||||
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
|
||||
|
||||
|
||||
@no_shard_zero_decrator(is_replicated=False)
|
||||
class Experts(MoeExperts):
|
||||
"""A wrapper class to create experts. It will create E experts across the
|
||||
moe model parallel group, where E is the number of experts. Every expert
|
||||
is a instance of the class, 'expert' in initialization parameters.
|
||||
|
||||
Args:
|
||||
expert_cls (:class:`torch.nn.Module`): The class of all experts
|
||||
num_experts (int): The number of experts
|
||||
expert_args: Args used to initialize experts, the args could be found in corresponding expert class
|
||||
"""
|
||||
|
||||
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
|
||||
super().__init__("all_to_all", num_experts)
|
||||
|
||||
# Use seed to make every expert different from others
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
|
||||
|
||||
# Attach parallel information for all parameters in Experts
|
||||
for exp in self.experts:
|
||||
for param in exp.parameters():
|
||||
param.__setattr__("moe_info", self.dist_info)
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
# Split inputs for each expert
|
||||
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
|
||||
expert_output = []
|
||||
|
||||
# Get outputs from each expert
|
||||
for i in range(self.num_local_experts):
|
||||
expert_output.append(self.experts[i](expert_input[i]))
|
||||
|
||||
# Concatenate all outputs together
|
||||
output = torch.cat(expert_output, dim=1).contiguous()
|
||||
return output
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
assert keep_vars == False, "Only support keep_vars=False now"
|
||||
dp_rank = dist.get_rank(self.dist_info.dp_group)
|
||||
ep_rank = dist.get_rank(self.dist_info.ep_group)
|
||||
submodule_dict = dict()
|
||||
example_submodule = None
|
||||
for name, subm in self.experts.named_modules():
|
||||
if subm is self.experts:
|
||||
continue
|
||||
module_number = self.num_local_experts * ep_rank + int(name)
|
||||
submodule_dict[module_number] = subm
|
||||
example_submodule = subm
|
||||
|
||||
if dp_rank == 0:
|
||||
local_prefix = prefix + "experts."
|
||||
buffer_module = deepcopy(example_submodule)
|
||||
for i in range(self.num_total_experts):
|
||||
source_rank = i // self.num_local_experts
|
||||
current_prefix = local_prefix + str(i) + "."
|
||||
comm_module = submodule_dict.get(i, buffer_module)
|
||||
for name, param in comm_module.named_parameters():
|
||||
dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
|
||||
if ep_rank == 0:
|
||||
destination[current_prefix + name] = param.data.cpu()
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class FFNExperts(MoeExperts):
|
||||
"""Use torch.bmm to speed up for multiple experts."""
|
||||
|
||||
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
|
||||
super().__init__("all_to_all", num_experts)
|
||||
|
||||
self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
|
||||
self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
|
||||
|
||||
self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
|
||||
self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
|
||||
|
||||
s1 = math.sqrt(0.1 / d_model)
|
||||
s2 = math.sqrt(0.1 / d_ff)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.w1, std=s1)
|
||||
nn.init.trunc_normal_(self.b1, std=s1)
|
||||
nn.init.trunc_normal_(self.w2, std=s2)
|
||||
nn.init.trunc_normal_(self.b2, std=s2)
|
||||
|
||||
self.act = nn.GELU() if activation is None else activation
|
||||
self.drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
for param in self.parameters():
|
||||
param.__setattr__("moe_info", self.dist_info)
|
||||
|
||||
def forward(self, inputs): # inputs [g, el, c, h]
|
||||
el = inputs.size(1)
|
||||
h = inputs.size(-1)
|
||||
|
||||
inputs = inputs.transpose(0, 1)
|
||||
inshape = inputs.shape
|
||||
inputs = inputs.reshape(el, -1, h)
|
||||
|
||||
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
|
||||
out_act = self.act(out_ff)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
out_inter = self.drop(out_act)
|
||||
|
||||
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
outputs = self.drop(out_model) # outputs [el, gc, h]
|
||||
|
||||
outputs = outputs.reshape(inshape)
|
||||
outputs = outputs.transpose(0, 1).contiguous()
|
||||
return outputs
|
||||
|
||||
|
||||
class TPExperts(MoeExperts):
|
||||
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
|
||||
case that the number of experts can't be divide by maximum expert parallel size or
|
||||
maximum expert parallel size can't be divide by the number of experts.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
|
||||
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
|
||||
|
||||
assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size"
|
||||
|
||||
p_ff = d_ff // MOE_CONTEXT.max_ep_size
|
||||
|
||||
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
|
||||
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
|
||||
|
||||
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
|
||||
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
|
||||
|
||||
s1 = math.sqrt(0.1 / d_model)
|
||||
s2 = math.sqrt(0.1 / d_ff)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.w1, std=s1)
|
||||
nn.init.trunc_normal_(self.b1, std=s1)
|
||||
nn.init.trunc_normal_(self.w2, std=s2)
|
||||
|
||||
nn.init.trunc_normal_(self.b2, std=s2)
|
||||
|
||||
self.act = nn.GELU() if activation is None else activation
|
||||
self.drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
self.w1.__setattr__("moe_info", self.dist_info)
|
||||
self.w2.__setattr__("moe_info", self.dist_info)
|
||||
self.b1.__setattr__("moe_info", self.dist_info)
|
||||
|
||||
def forward(self, inputs): # inputs [g, e, c, h]
|
||||
e = inputs.size(1)
|
||||
h = inputs.size(-1)
|
||||
|
||||
inputs = inputs.transpose(0, 1)
|
||||
inshape = inputs.shape
|
||||
inputs = inputs.reshape(e, -1, h)
|
||||
|
||||
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
|
||||
out_act = self.act(out_ff)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
out_inter = self.drop(out_act)
|
||||
|
||||
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
|
||||
outputs = self.drop(out_model) # outputs [e, gc, h]
|
||||
|
||||
outputs = outputs.reshape(inshape)
|
||||
outputs = outputs.transpose(0, 1).contiguous()
|
||||
return outputs # outputs [g, e, c, h]
|
|
@ -1,212 +0,0 @@
|
|||
import math
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
|
||||
from colossalai.nn.layer.moe._operation import (
|
||||
COL_MOE_KERNEL_FLAG,
|
||||
AllGather,
|
||||
AllToAll,
|
||||
MoeCombine,
|
||||
MoeDispatch,
|
||||
ReduceScatter,
|
||||
)
|
||||
from colossalai.nn.layer.moe.experts import Experts, MoeExperts
|
||||
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
|
||||
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@no_shard_zero_decrator(is_replicated=True)
|
||||
class MoeLayer(nn.Module):
|
||||
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
|
||||
to router all tokens, is mainly used to exchange all tokens for every expert across
|
||||
the moe tensor group by all to all communication. Then it will get the output of all
|
||||
experts and exchange the output. At last returns the output of the moe system.
|
||||
|
||||
Args:
|
||||
dim_model (int): Dimension of model.
|
||||
num_experts (int): The number of experts.
|
||||
router (MoeRouter): Instance of router used in routing.
|
||||
experts (MoeExperts): Instance of experts generated by Expert.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
|
||||
super().__init__()
|
||||
self.d_model = dim_model
|
||||
self.num_experts = num_experts
|
||||
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
|
||||
self.router: MoeRouter = router
|
||||
self.experts: MoeExperts = experts
|
||||
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
|
||||
self.ep_group = experts.dist_info.ep_group
|
||||
self.ep_size = experts.dist_info.ep_size
|
||||
self.num_local_experts = experts.num_local_experts
|
||||
|
||||
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
|
||||
|
||||
def a2a_process(self, dispatch_data: torch.Tensor):
|
||||
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
|
||||
input_shape = expert_input.shape
|
||||
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
|
||||
expert_output = self.experts(expert_input)
|
||||
expert_output = expert_output.reshape(input_shape)
|
||||
expert_output = AllToAll.apply(expert_output, self.ep_group)
|
||||
return expert_output
|
||||
|
||||
def tp_process(self, dispatch_data: torch.Tensor):
|
||||
expert_in = AllGather.apply(dispatch_data, self.ep_group)
|
||||
expert_out = self.experts(expert_in)
|
||||
expert_out = ReduceScatter.apply(expert_out, self.ep_group)
|
||||
return expert_out
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> Tuple:
|
||||
# reshape the input tokens
|
||||
tokens = inputs.reshape(-1, self.d_model)
|
||||
|
||||
# the data type of the inputs in the gating should be fp32
|
||||
fp32_input = tokens.to(torch.float)
|
||||
fp32_weight = self.gate_weight.to(torch.float)
|
||||
gate_output = F.linear(fp32_input, fp32_weight)
|
||||
|
||||
# the result from the router
|
||||
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
|
||||
|
||||
if self.use_kernel:
|
||||
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
|
||||
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
|
||||
else:
|
||||
sec_mask_f = route_result_list[1].type_as(inputs)
|
||||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
# dispatch_data [e, c, h]
|
||||
if self.experts.comm_name == "all_to_all":
|
||||
expert_output = self.a2a_process(dispatch_data)
|
||||
elif self.experts.comm_name == "all_gather":
|
||||
expert_output = self.tp_process(dispatch_data)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"This kind of communication has not been implemented yet.\n Please use Experts " "build function."
|
||||
)
|
||||
# expert_output [e, c, h]
|
||||
if self.use_kernel:
|
||||
expert_output = expert_output.reshape(-1, self.d_model)
|
||||
ans = MoeCombine.apply(expert_output, *route_result_list)
|
||||
else:
|
||||
combine_weights = route_result_list[0].type_as(inputs)
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
ans = torch.matmul(combine_weights, expert_output)
|
||||
|
||||
ans = ans.reshape(inputs.shape)
|
||||
l_aux = self.router.pop_routing_loss()
|
||||
return ans, l_aux
|
||||
|
||||
|
||||
class MoeModule(nn.Module):
|
||||
"""A class for users to create MoE modules in their models.
|
||||
|
||||
Args:
|
||||
dim_model (int): Hidden dimension of training model
|
||||
num_experts (int): The number experts
|
||||
top_k (int, optional): The number of experts for dispatchment of each token
|
||||
capacity_factor_train (float, optional): Capacity factor in routing during training
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert
|
||||
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
|
||||
'Jitter' can be found in `Switch Transformer paper`_.
|
||||
'Gaussian' can be found in `ViT-MoE paper`_.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
|
||||
More information can be found in `Microsoft paper`_.
|
||||
residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
|
||||
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
|
||||
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
|
||||
expert_args (optional): The args of expert when no instance is given
|
||||
|
||||
.. _Switch Transformer paper:
|
||||
https://arxiv.org/abs/2101.03961
|
||||
.. _ViT-MoE paper:
|
||||
https://arxiv.org/abs/2106.05974
|
||||
.. _Microsoft paper:
|
||||
https://arxiv.org/abs/2201.05596
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
num_experts: int,
|
||||
top_k: int = 1,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_policy: Optional[str] = None,
|
||||
drop_tks: bool = True,
|
||||
use_residual: bool = False,
|
||||
residual_instance: Optional[nn.Module] = None,
|
||||
expert_instance: Optional[MoeExperts] = None,
|
||||
expert_cls: Optional[Type[nn.Module]] = None,
|
||||
**expert_args,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
noisy_func = None
|
||||
if noisy_policy is not None:
|
||||
if noisy_policy == "Jitter":
|
||||
noisy_func = UniformNoiseGenerator()
|
||||
elif noisy_policy == "Gaussian":
|
||||
noisy_func = NormalNoiseGenerator(num_experts)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported input noisy policy")
|
||||
|
||||
if top_k == 1:
|
||||
moe_router_cls = Top1Router
|
||||
elif top_k == 2:
|
||||
moe_router_cls = Top2Router
|
||||
else:
|
||||
raise NotImplementedError("top_k > 2 is not supported yet")
|
||||
|
||||
self.moe_router = moe_router_cls(
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.use_residual = use_residual
|
||||
if use_residual:
|
||||
if residual_instance is not None:
|
||||
self.residual_module = residual_instance
|
||||
else:
|
||||
assert expert_cls is not None, "Expert class can't be None when residual instance is not given"
|
||||
self.residual_module = expert_cls(**expert_args)
|
||||
|
||||
with no_shard_zero_context():
|
||||
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
|
||||
|
||||
if expert_instance is not None:
|
||||
my_experts = expert_instance
|
||||
else:
|
||||
assert expert_cls is not None, "Expert class can't be None when experts instance is not given"
|
||||
my_experts = Experts(expert_cls, num_experts, **expert_args)
|
||||
|
||||
self.moe_layer = MoeLayer(
|
||||
dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
moe_output, l_aux = self.moe_layer(inputs)
|
||||
|
||||
if self.use_residual:
|
||||
residual_output = self.residual_module(inputs)
|
||||
combine_coef = self.residual_combine(inputs)
|
||||
combine_coef = F.softmax(combine_coef, dim=-1)
|
||||
output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
|
||||
else:
|
||||
output = moe_output
|
||||
|
||||
return output, l_aux
|
|
@ -1,235 +0,0 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.nn.layer.moe._operation import moe_cumsum
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum number of the capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._routing_loss = None
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return capacity
|
||||
|
||||
def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
|
||||
assert self._routing_loss is None
|
||||
self._routing_loss = aux_loss
|
||||
|
||||
def pop_routing_loss(self) -> torch.Tensor:
|
||||
assert self._routing_loss is not None
|
||||
reservation = self._routing_loss
|
||||
self._routing_loss = None
|
||||
return reservation
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about Switch Transformer
|
||||
of Google.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert.
|
||||
select_policy (str, optional): The policy about tokens selection.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
|
||||
).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
        # calculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce)
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * logits.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about ViT-MoE.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
# inputs: [s, h]
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(logits, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = mask1 + mask2 # loss: [s, e]
|
||||
|
||||
        # calculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(cmask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
weight1 = mask1 * logits.type_as(inputs)
|
||||
weight2 = mask2 * logits.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
|
||||
return cb_weight, sec_mask
|
|
@ -1,71 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .experts import FFNExperts, TPExperts
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
def half(self, memory_format=None):
|
||||
return self.data.clone()
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
"""Generates a random noisy mask for logits tensor.
|
||||
|
||||
All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
|
||||
`E = the number of experts`.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int):
|
||||
self.normal = torch.distributions.normal.Normal(
|
||||
loc=torch.tensor(0.0, device=get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.normal(inputs.shape)
|
||||
return inputs + noisy
|
||||
|
||||
|
||||
class UniformNoiseGenerator:
|
||||
"""Generates a random noisy mask for logits tensor.
|
||||
copied from mesh tensorflow:
|
||||
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
|
||||
Makes models more resilient to rounding errors introduced by bfloat16.
|
||||
This seems particularly important for logits.
|
||||
|
||||
Args:
|
||||
eps (float, optional): Epsilon in generator, defaults 1e-2.
|
||||
"""
|
||||
|
||||
def __init__(self, eps: float = 1e-2):
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(1.0 - eps, device=get_current_device()),
|
||||
high=torch.tensor(1.0 + eps, device=get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.uniform(inputs.shape)
|
||||
return inputs * noisy
|
||||
|
||||
|
||||
def autocast_softmax(logit: torch.Tensor, dim: int):
|
||||
if logit.dtype != torch.float32:
|
||||
logit = logit.float()
|
||||
return F.softmax(logit, dim=dim)
|
||||
|
||||
|
||||
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
|
||||
mep_size = MOE_CONTEXT.max_ep_size
|
||||
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
|
||||
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
|
||||
elif d_ff % mep_size == 0:
|
||||
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
|
||||
else:
|
||||
raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
|
|
@ -1 +0,0 @@
|
|||
# from .loss_moe import MoeCrossEntropyLoss, MoeLoss
|
|
@ -0,0 +1,137 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .moe_info import MoeParallelInfo
|
||||
|
||||
|
||||
def is_moe_tensor(tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Check whether the given tensor is a moe tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
bool: Whether the given tensor is a moe tensor.
|
||||
"""
|
||||
return hasattr(tensor, "moe_info")
|
||||
|
||||
|
||||
def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
|
||||
"""
|
||||
Set moe info for the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be set.
|
||||
moe_info (dict): The moe info to be set.
|
||||
|
||||
"""
|
||||
tensor.__setattr__("moe_info", moe_info)
|
||||
|
||||
|
||||
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
|
||||
"""
|
||||
Get moe info for the given tensor.
|
||||
|
||||
Args:
|
||||
ep_size (int): The expert parallel size.
|
||||
dp_size (int): The data parallel size.
|
||||
pp_size (int): The pipeline parallel size.
|
||||
ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle.
|
||||
|
||||
Returns:
|
||||
dict: The moe info of the given tensor.
|
||||
"""
|
||||
return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)
|
||||
|
||||
|
||||
def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
|
||||
"""
|
||||
Get the expert parallel group of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
|
||||
"""
|
||||
return tensor.moe_info.ep_group
|
||||
|
||||
|
||||
def get_ep_size(tensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Get the expert parallel size of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
int: The expert parallel size of the given tensor.
|
||||
"""
|
||||
return tensor.moe_info.ep_size
|
||||
|
||||
|
||||
def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
|
||||
"""
|
||||
Get the data parallel group of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
torch.distributed.ProcessGroup: The data parallel group of the given tensor.
|
||||
"""
|
||||
return tensor.moe_info.dp_group
|
||||
|
||||
|
||||
def get_ep_rank(tensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Get the expert parallel rank of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
int: The expert parallel rank of the given tensor.
|
||||
"""
|
||||
return dist.get_rank(get_ep_group(tensor))
|
||||
|
||||
|
||||
def get_dp_rank(tensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Get the data parallel rank of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
int: The data parallel rank of the given tensor.
|
||||
"""
|
||||
return dist.get_rank(get_dp_group(tensor))
|
||||
|
||||
|
||||
def get_ep_group_ranks(tensor: torch.Tensor) -> list:
    """
    Get the expert parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        list: The expert parallel group ranks of the given tensor.
    """
    return tensor.moe_info.ep_group_ranks
|
||||
|
||||
|
||||
def get_dp_group_ranks(tensor: torch.Tensor) -> list:
    """
    Get the data parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        list: The data parallel group ranks of the given tensor.
    """
    return tensor.moe_info.dp_group_ranks
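# Illustrative sketch (helper name and sizes are arbitrary; assumes a distributed
# launch whose world size equals ep_size * dp_size * pp_size): tag a parameter as
# a MoE tensor and query its parallel groups through the accessors above.
def _example_tag_moe_param(param: torch.Tensor) -> None:
    info = get_moe_info(ep_size=2, dp_size=4, pp_size=1, ep_inside=True)
    set_moe_tensor_info(param, info)
    assert is_moe_tensor(param)
    _ = (get_ep_size(param), get_ep_rank(param), get_dp_rank(param))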
|
|
@ -0,0 +1,28 @@
|
|||
from colossalai.cluster import ProcessGroupMesh
|
||||
|
||||
|
||||
class MoeParallelInfo:
|
||||
"""Moe parallelism information, storing parallel sizes and groups."""
|
||||
|
||||
def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1):
|
||||
"""
|
||||
init MoeParallelInfo with ep_size, dp_size and pp_size
|
||||
|
||||
Args:
|
||||
ep_size (int): expert parallel size
|
||||
dp_size (int): data parallel (zero) size
|
||||
pp_size (int, optional): pipeline parallel size. Defaults to 1.
|
||||
ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True.
|
||||
"""
|
||||
self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size
|
||||
if ep_inside:
|
||||
self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2
|
||||
self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size)
|
||||
else:
|
||||
self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2
|
||||
self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size)
|
||||
|
||||
self.ep_group = self.pg.get_group_along_axis(self.ep_axis)
|
||||
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
|
||||
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
|
||||
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
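# Illustrative sketch (helper name and sizes are arbitrary; assumes torch.distributed
# is initialized with pp_size * dp_size * ep_size ranks): with ep_inside=True the
# process mesh is laid out as [pp, dp, ep], so the ranks inside one data-parallel
# replica form an expert-parallel group.
def _example_build_moe_parallel_info() -> MoeParallelInfo:
    return MoeParallelInfo(ep_inside=True, ep_size=2, dp_size=4, pp_size=1)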
|
|
@ -1,53 +0,0 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.utils import is_using_ddp
|
||||
|
||||
|
||||
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
|
||||
"""Returns a parameter dictionary, the key of which is the expert parallel
|
||||
size of every parameter. Since the parameters in data parallelism is replicated
|
||||
in each GPU, we set their ep_size to 1.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
|
||||
"""
|
||||
epsize_param_dict = dict()
|
||||
for param in model.parameters():
|
||||
if not hasattr(param, "moe_info"):
|
||||
ep_size = 1 # set ep_size to 1 for dp parameters
|
||||
else:
|
||||
ep_size = param.moe_info.ep_size
|
||||
if ep_size not in epsize_param_dict:
|
||||
epsize_param_dict[ep_size] = []
|
||||
epsize_param_dict[ep_size].append(param)
|
||||
|
||||
return epsize_param_dict
|
||||
|
||||
|
||||
def sync_moe_model_param(model: nn.Module):
|
||||
"""Make sure model parameters are consistent in MoE parallel context.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
"""
|
||||
if is_using_ddp():
|
||||
param_dict = get_moe_epsize_param_dict(model)
|
||||
|
||||
# synchronize the parameters whose dp_group is the whole world
|
||||
if 1 in param_dict:
|
||||
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
|
||||
for param in param_dict[1]:
|
||||
dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
for ep_size in param_dict:
|
||||
# When ep_size = world_size, communication is not needed
|
||||
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
|
||||
src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
|
||||
for param in param_dict[ep_size]:
|
||||
dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
|
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, inf
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
@ -18,6 +19,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
|||
)
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
# from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
@ -75,6 +77,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
|
||||
master_weights: bool = True, # master weights
|
||||
):
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
|
@ -95,6 +98,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
self._local_rank = dist.get_rank(group=self.dp_pg)
|
||||
self._world_size = dist.get_world_size(group=self.dp_pg)
|
||||
|
||||
        # extra dp
        # This group is used to sync moe params; dp_world_size = moe_duplicates * extra_dp_size.
        # Non-moe params are synced by the global dp pg, while moe params are synced by the extra dp pg.
        # Moe param grads are split like non-moe params by the global dp pg, and the grads are merged in step().
        # Moe working and master params are split by the extra dp pg.
|
||||
self.moe_extra_dp_pg = moe_extra_dp_process_group
|
||||
if self.moe_extra_dp_pg is not None:
|
||||
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
|
||||
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
|
||||
|
||||
# working and master params for mixed precision training
|
||||
self._working_param_groups = dict()
|
||||
self._master_param_groups_of_current_rank = dict()
|
||||
|
@ -126,6 +139,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
|
||||
self._bucket_store = BucketStore(self.dp_pg)
|
||||
|
||||
        # moe params should not be stored in working_groups
        # because they use a different parallel strategy,
        # so we store them separately in param_groups
        # instead of working_groups
|
||||
moe_params = list()
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
# and add buffers to parameter store for future access
|
||||
|
@ -133,6 +152,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
group_params = list()
|
||||
for param in param_group["params"]:
|
||||
if param.requires_grad:
|
||||
if self.moe_extra_dp_pg is None:
|
||||
# skip moe param
|
||||
if is_moe_tensor(param):
|
||||
moe_params.append(param)
|
||||
continue
|
||||
group_params.append(param)
|
||||
|
||||
# add the working params to working_param_groups for bookkeeping
|
||||
|
@ -146,6 +170,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# managed by this data parallel rank
|
||||
param_group["params"] = master_param_current_rank
|
||||
|
||||
        # if there are moe params, store them in an additional group in the optimizer
|
||||
if len(moe_params) > 0:
|
||||
param_group = dict()
|
||||
for key, value in self.optim.param_groups[0].items():
|
||||
if key != "params":
|
||||
param_group[key] = value
|
||||
param_group["params"] = moe_params
|
||||
self.optim.param_groups.append(param_group)
|
||||
|
||||
        # initialize communication stream for
        # communication-computation overlapping
|
||||
if self._overlap_communication:
|
||||
|
@ -208,13 +241,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
param.data = padding_param[: param.numel()].view(param.shape)
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // self._world_size)
|
||||
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
|
||||
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
|
||||
splited_params = splited_params[self.moe_extra_dp_pg_rank]
|
||||
else:
|
||||
splited_params = padding_param.split(padding_param.numel() // self._world_size)
|
||||
splited_params = splited_params[self._local_rank]
|
||||
|
||||
# use fp32 when master_weights is True
|
||||
if self._master_weights is True:
|
||||
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
|
||||
splited_param_current_rank = splited_params.detach().float().to(device)
|
||||
else:
|
||||
splited_param_current_rank = splited_params[self._local_rank]
|
||||
splited_param_current_rank = splited_params
|
||||
|
||||
params_current_rank.append(splited_param_current_rank)
|
||||
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
|
||||
|
||||
|
@ -247,8 +287,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if self._bucket_store.num_elements_in_bucket() > 0:
|
||||
self._bucket_store.build_grad_in_bucket()
|
||||
|
||||
flat_grads = self._bucket_store.get_flatten_grad()
|
||||
flat_grads /= self._world_size
|
||||
if self.moe_extra_dp_pg is None:
|
||||
flat_grads = self._bucket_store.get_flatten_grad()
|
||||
flat_grads /= self._world_size
|
||||
else:
|
||||
# record moe and non moe param
|
||||
moe_list = []
|
||||
for param in self._bucket_store._param_list:
|
||||
moe_list.append(is_moe_tensor(param))
|
||||
|
||||
# divide them into different groups
|
||||
moe_grad_list = []
|
||||
non_moe_grad_list = []
|
||||
for grad_list in self._bucket_store._grad_in_bucket.values():
|
||||
non_moe_cur_grad = []
|
||||
moe_cur_grad = []
|
||||
for i in range(len(grad_list)):
|
||||
if moe_list[i]:
|
||||
moe_cur_grad.append(grad_list[i])
|
||||
else:
|
||||
non_moe_cur_grad.append(grad_list[i])
|
||||
if len(moe_cur_grad) > 0:
|
||||
moe_grad_list.append(moe_cur_grad)
|
||||
if len(non_moe_cur_grad) > 0:
|
||||
non_moe_grad_list.append(non_moe_cur_grad)
|
||||
|
||||
if len(non_moe_grad_list) > 0:
|
||||
non_moe_flat_grads = []
|
||||
for grad_list in non_moe_grad_list:
|
||||
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
|
||||
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
|
||||
non_moe_flat_grads /= self._world_size
|
||||
|
||||
if len(moe_grad_list) > 0:
|
||||
moe_flat_grads = []
|
||||
for grad_list in moe_grad_list:
|
||||
moe_flat_grads.append(_flatten_dense_tensors(grad_list))
|
||||
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
|
||||
|
||||
# ready to add other tensors to bucket
|
||||
self._bucket_store.reset_num_elements_in_bucket()
|
||||
|
@ -256,7 +331,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
# in case of the memory being reused in the default stream
|
||||
flat_grads.record_stream(stream)
|
||||
if self.moe_extra_dp_pg is None:
|
||||
flat_grads.record_stream(stream)
|
||||
else:
|
||||
if len(non_moe_grad_list) > 0:
|
||||
non_moe_flat_grads.record_stream(stream)
|
||||
if len(moe_grad_list) > 0:
|
||||
moe_flat_grads.record_stream(stream)
|
||||
# wait for ops in the default stream to finish
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
else:
|
||||
|
@ -265,49 +346,108 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
with torch.cuda.stream(stream):
|
||||
group_id = self._bucket_store.current_group_id
|
||||
|
||||
grad_dtype = flat_grads.dtype
|
||||
if self._communication_dtype is not None:
|
||||
flat_grads = flat_grads.to(self._communication_dtype)
|
||||
if self.moe_extra_dp_pg is None:
|
||||
grad_dtype = flat_grads.dtype
|
||||
if self._communication_dtype is not None:
|
||||
flat_grads = flat_grads.to(self._communication_dtype)
|
||||
|
||||
if not self._partition_grads:
|
||||
dist.all_reduce(flat_grads, group=self.dp_pg)
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
if self.moe_extra_dp_pg is None:
|
||||
dist.all_reduce(flat_grads, group=self.dp_pg)
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
|
||||
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
|
||||
grad_in_bucket = self._bucket_store.get_grad()
|
||||
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
|
||||
grad_in_bucket = self._bucket_store.get_grad()
|
||||
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
||||
|
||||
for rank, grad_list in grad_in_bucket.items():
|
||||
sync_tensor(flat_grads_per_rank[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
if (
|
||||
len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
|
||||
< self._world_size
|
||||
):
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
||||
# sync extra zero group
|
||||
else:
|
||||
# sync non moe param in global dp group
|
||||
if len(non_moe_grad_list) > 0:
|
||||
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
|
||||
flat_grads_per_rank = non_moe_flat_grads.split(
|
||||
non_moe_flat_grads.numel() // self._world_size
|
||||
)
|
||||
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
|
||||
|
||||
# sync moe param only in zero group
|
||||
if len(moe_grad_list) > 0:
|
||||
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
|
||||
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
|
||||
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
|
||||
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
||||
if self.moe_extra_dp_pg is None:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
||||
|
||||
if recieved_grad.dtype != grad_dtype:
|
||||
recieved_grad = recieved_grad.to(grad_dtype)
|
||||
if recieved_grad.dtype != grad_dtype:
|
||||
recieved_grad = recieved_grad.to(grad_dtype)
|
||||
|
||||
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
||||
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
|
||||
for grad in grad_in_bucket_current_rank:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
|
||||
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
||||
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
|
||||
else:
|
||||
# categorize moe and non moe param
|
||||
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
||||
moe_grad_in_bucket_current_rank = []
|
||||
non_moe_grad_in_bucket_current_rank = []
|
||||
for idx, grad in enumerate(grad_in_bucket_current_rank):
|
||||
if moe_list[idx]:
|
||||
moe_grad_in_bucket_current_rank.append(grad)
|
||||
else:
|
||||
non_moe_grad_in_bucket_current_rank.append(grad)
|
||||
|
||||
if len(non_moe_grad_list) > 0:
|
||||
flat_grads_list = list(
|
||||
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
|
||||
)
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
||||
self._update_partitoned_grad(
|
||||
non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
|
||||
)
|
||||
|
||||
if len(moe_grad_list) > 0:
|
||||
flat_grads_list = list(
|
||||
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
|
||||
)
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
|
||||
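# the reduce-scattered shard still covers (world_size // extra_dp_size) per-rank grad groups, so split it again before storing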
param_slice = self._world_size // self.moe_extra_dp_pg_size
|
||||
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
|
||||
for split_recieved_grad in recieved_grad:
|
||||
split_recieved_grad = _unflatten_dense_tensors(
|
||||
split_recieved_grad, moe_grad_in_bucket_current_rank
|
||||
)
|
||||
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._add_grad(real_grad, param_slice, group_id, param_id)
|
||||
|
||||
self._bucket_store.reset()
|
||||
|
||||
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
|
||||
for rank, grad_list in enumerate(origin_grad_list):
|
||||
sync_tensor(flat_grad_list[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._add_grad(grad, self._world_size, group_id, param_id, rank)
|
||||
|
||||
def _update_partitoned_grad(
|
||||
self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
|
||||
) -> None:
|
||||
sync_tensor(flat_grad, origin_grad_list)
|
||||
for grad in origin_grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._add_grad(grad, partition_num, group_id, param_id)
|
||||
|
||||
def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
||||
|
||||
def _add_to_bucket(self, param, group_id):
|
||||
param_size = param.numel()
|
||||
|
||||
|
@ -424,13 +564,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# else the splited grad should be attached to the splited param
|
||||
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
|
||||
if len(grads) > 0:
|
||||
real_working_params[group_id].append(working_param)
|
||||
# moe hybrid zero
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
||||
real_working_params[group_id].append(working_param)
|
||||
if self._partition_grads:
|
||||
grad = grads
|
||||
else:
|
||||
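# grads were stored per global-dp rank; this extra-dp rank owns a contiguous slice of those partitions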
param_slice = self._world_size // self.moe_extra_dp_pg_size
|
||||
grad = grads[
|
||||
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
|
||||
]
|
||||
grad = flatten(grad)
|
||||
else:
|
||||
real_working_params[group_id].append(working_param)
|
||||
grad = grads[grad_index]
|
||||
# no need to copy fp32 grad if master_weights is False
|
||||
grad = (
|
||||
grads[grad_index].to(splited_param.dtype).to(splited_param.device)
|
||||
if self._master_weights
|
||||
else grads[grad_index]
|
||||
)
|
||||
if self._master_weights:
|
||||
grad = grad.to(splited_param.dtype).to(splited_param.device)
|
||||
splited_param.grad = grad
|
||||
grad_partition_groups.append(grad)
|
||||
real_master_params[group_id].append(splited_param)
|
||||
|
@ -449,24 +599,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
|
||||
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
|
||||
|
||||
# TODO: we should store master param for ep
|
||||
if len(self.param_groups) > len(self._working_param_groups):
|
||||
for param in self.param_groups[-1]["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
param.grad = param.grad.to(torch.float32)
|
||||
|
||||
# update the parameters
|
||||
self.optim.step()
|
||||
|
||||
# release the moe grads
|
||||
if len(self.param_groups) > len(self._working_param_groups):
|
||||
for param in self.param_groups[-1]["params"]:
|
||||
param.grad = None
|
||||
param.data = param.data.to(self._dtype)
|
||||
|
||||
# release the grad
|
||||
grad_partition_groups = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
release_param_grad(self._master_param_groups_of_current_rank[group_id])
|
||||
|
||||
# update working partition updated by the current rank
|
||||
# dtype = real_working_params[0][0].dtype
|
||||
for group_id in range(self.num_param_groups):
|
||||
master_working_param = self.optim.param_groups[group_id]["params"]
|
||||
for idx, splited_param in enumerate(master_working_param):
|
||||
working_param = real_working_params[group_id][idx]
|
||||
all_splited_param = [
|
||||
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
||||
all_splited_param = [
|
||||
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
|
||||
for _ in range(self.moe_extra_dp_pg_size)
|
||||
]
|
||||
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
|
||||
else:
|
||||
all_splited_param = [
|
||||
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
|
||||
for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
|
||||
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
|
||||
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
||||
|
||||
|
@ -488,7 +657,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
norm_type = float(norm_type)
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
@ -596,10 +764,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
working_param = self._param_store.master_to_working_param[id(param)]
|
||||
gather_tensor = [
|
||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
||||
gather_tensor = [
|
||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
||||
]
|
||||
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
|
||||
else:
|
||||
gather_tensor = [
|
||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
|
||||
param_state = (
|
||||
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
|
@ -624,8 +798,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self._world_size)
|
||||
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
||||
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
|
||||
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
|
||||
else:
|
||||
v_list = v.split(v.numel() // self._world_size)
|
||||
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
|
||||
|
||||
self.optim.load_state_dict(zero_state_dict)
|
||||
|
||||
|
@ -656,8 +834,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
|
||||
for k, v in states.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
|
||||
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
||||
state_tensor = [
|
||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
||||
]
|
||||
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
|
||||
else:
|
||||
state_tensor = [
|
||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
|
||||
state_tensor = (
|
||||
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
|
@ -688,7 +874,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
working_param = p.data.view(-1)
|
||||
if padding_size > 0:
|
||||
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
|
||||
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
|
||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
|
||||
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
|
||||
else:
|
||||
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
|
||||
|
||||
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self._param_store.working_to_master_param
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
## OpenMoE
|
||||
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in Jax, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates fine-tuning and inference methods.
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Installation
|
||||
|
||||
Please install the latest ColossalAI from source.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
|
||||
```
|
||||
|
||||
Then install dependencies.
|
||||
|
||||
```bash
|
||||
cd ColossalAI/examples/language/openmoe
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Additionally, we recommend using torch 1.13.1. We have tested our code on torch 1.13.1 and found it compatible with our code and flash attention.
|
||||
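For example, you could pin the recommended version with pip (adjust the command to match your CUDA setup):

```bash
pip install torch==1.13.1
```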
|
||||
### 2. Install kernels (Optional)
|
||||
|
||||
We utilize the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
|
||||
```bash
|
||||
# install triton via pip
|
||||
pip install triton
|
||||
|
||||
# install flash attention via pip
|
||||
pip install flash-attn==2.0.5
|
||||
|
||||
# install apex from source
|
||||
git clone https://github.com/NVIDIA/apex.git
|
||||
cd apex
|
||||
git checkout 741bdf50825a97664db08574981962d66436d16a
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
|
||||
```
|
||||
|
||||
### 3. Train
|
||||
You can use `colossalai run` to launch single-node training:
|
||||
```bash
|
||||
colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
You can also use `colossalai run` to launch multi-node training:
|
||||
```bash
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
|
||||
Here is a sample hostfile:
|
||||
|
||||
```text
|
||||
hostname1
|
||||
hostname2
|
||||
hostname3
|
||||
hostname4
|
||||
```
|
||||
|
||||
Each hostname refers to the IP address of a node. Make sure the master node can access all nodes (including itself) via SSH without a password.
|
||||
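As a minimal sketch, passwordless SSH can typically be set up from the master node as follows (the user and host names are placeholders):

```bash
# generate a key pair on the master node if one does not already exist
ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa
# copy the public key to every node in the hostfile (including the master itself)
ssh-copy-id user@hostname1
```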
|
||||
Here are the details of the CLI arguments; a combined launch example follows the list:
|
||||
|
||||
- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
|
||||
- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
|
||||
- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
|
||||
- Number of epochs: `--num_epochs`. The default value is 1.
|
||||
- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
|
||||
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
|
||||
- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
|
||||
- Max length: `--max_length`. Max sequence length. Defaults to 2048.
|
||||
- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format.
|
||||
- Task name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
|
||||
- Learning rate: `--lr`. The default value is 1e-5.
|
||||
- Weight decay: `--weight_decay`. The default value is 0.
|
||||
- Zero stage: `--zero_stage`. ZeRO stage. 2 is recommended for `ep` and 1 for `ep_zero`.
|
||||
- Extra dp size: `--extra_dp_size`. Extra data-parallel size for moe params with the `ep_zero` plugin. Recommended to be 2 or 4.
|
||||
- Use kernel: `--use_kernel`. Use kernel optimizations. Flash attention and triton need to be installed to enable all kernel optimizations; they are skipped if not installed.
|
||||
- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Apex needs to be installed; an error is raised if it is not.
|
||||
- Router aux loss factor: `--router_aux_loss_factor`. MoE router aux loss factor. You can refer to STMoE for details.
|
||||
- Router z loss factor: `--router_z_loss_factor`. MoE router z loss factor. You can refer to STMoE for details.
|
||||
- Label smoothing: `--label_smoothing`. Label smoothing.
|
||||
- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
|
||||
- Load balance: `--load_balance`. Expert load balance. Defaults to False. Recommended to enable.
|
||||
- Load balance interval: `--load_balance_interval`. Expert load balance interval.
|
||||
- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
|
||||
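For example, a single-node launch that combines several of the flags above might look like this (the values are illustrative; adapt them to your hardware and dataset):

```bash
colossalai run --standalone --nproc_per_node 8 train.py \
    --model_name 8b \
    --plugin ep_zero \
    --extra_dp_size 2 \
    --zero_stage 1 \
    --batch_size 1 \
    --precision bf16 \
    --lr 1e-5 \
    --use_kernel \
    --load_balance
```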
|
||||
### 4. Shell Script Examples
|
||||
|
||||
For your convenience, we provide some shell scripts for training with various configurations. Here we will show an example of how to train
|
||||
OpenMoE.
|
||||
|
||||
#### a. Running environment
|
||||
This experiment was performed on a single compute node with 8 A800 80GB GPUs for OpenMoE-8B. The GPUs are fully connected with NVLink.
|
||||
|
||||
#### b. Running command
|
||||
We demonstrate how to run the three plugins in `train.sh`. You can choose any of them and use your own args.
|
||||
|
||||
```bash
|
||||
bash train.sh
|
||||
```
|
||||
|
||||
#### c. Multi-Nodes Training
|
||||
|
||||
To run on multiple nodes, you can modify the script as follows:
|
||||
```bash
|
||||
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
|
||||
train.py --OTHER_CONFIGURATIONS
|
||||
```
|
||||
|
||||
## Reference
|
||||
```bibtex
|
||||
@article{bian2021colossal,
|
||||
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
|
||||
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
|
||||
journal={arXiv preprint arXiv:2110.14883},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{openmoe2023,
|
||||
author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
|
||||
title = {OpenMoE: Open Mixture-of-Experts Language Models},
|
||||
year = {2023},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,296 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from huggingface_hub import snapshot_download
|
||||
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
|
||||
from model.openmoe_policy import OpenMoeForCausalLMPolicy
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import T5Tokenizer
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from utils import PerformanceEvaluator, get_model_numel
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
|
||||
ckpt_path = snapshot_download(repo_name)
|
||||
# single ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
|
||||
# shard ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
else:
|
||||
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||
booster.load_model(model, ckpt_path)
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(
|
||||
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
|
||||
):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
if os.path.exists("./mock_data.json"):
|
||||
self.input_ids = []
|
||||
self.attention_mask = []
|
||||
with open("./mock_data.json", "r") as f:
|
||||
data = json.load(f)
|
||||
for v in data.values():
|
||||
d = v["text"]
|
||||
encode = tokenizer(
|
||||
"<pad>" + d,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
self.input_ids.append(encode["input_ids"])
|
||||
self.attention_mask.append(encode["attention_mask"])
|
||||
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
|
||||
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
|
||||
repeat_times = num_samples // self.input_ids.shape[0] + 1
|
||||
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
|
||||
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
|
||||
else:
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b"],
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per dp group) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="sequence length for the training dataloader.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
help="parallel plugin",
|
||||
)
|
||||
# hybrid plugin
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
|
||||
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
|
||||
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
|
||||
parser.add_argument("--extra_dp_size", type=int, default=1)
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
|
||||
)
|
||||
# bench
|
||||
parser.add_argument("--warmup", type=int, default=20)
|
||||
parser.add_argument("--active", type=int, default=20)
|
||||
# load balance
|
||||
parser.add_argument("--load_balance", action="store_true")
|
||||
|
||||
# overlap
|
||||
parser.add_argument("--overlap_alltoall", action="store_true")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": OpenMoeForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": "bf16",
|
||||
"zero_stage": args.zero_stage,
|
||||
}
|
||||
mgr_dict = {
|
||||
"seed": 42,
|
||||
}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "ep_zero":
|
||||
dp_size = dist.get_world_size()
|
||||
use_ep_inside = False
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
extra_dp_size=args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size // args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "hybrid":
|
||||
dp_size = dist.get_world_size() // args.pp_size
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=args.pp_size,
|
||||
zero_stage=args.zero_stage,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=args.dp_size,
|
||||
fixed_ep_size=args.ep_size,
|
||||
fixed_pp_size=args.pp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin}")
|
||||
|
||||
# Build OpenMoe model
|
||||
repo_name = "hpcaitech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_load_balance=args.load_balance,
|
||||
enable_kernel=args.use_kernel,
|
||||
enable_comm_overlap=args.overlap_alltoall,
|
||||
)
|
||||
with skip_init():
|
||||
model = OpenMoeForCausalLM(config)
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
dataset = RandomDataset(
|
||||
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
|
||||
max_length=args.seq_length,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
performance_evaluator = PerformanceEvaluator(
|
||||
model_numel,
|
||||
enable_grad_checkpoint=True,
|
||||
ignore_steps=args.warmup,
|
||||
dp_world_size=dp_size,
|
||||
)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
load_ckpt(repo_name, model, booster)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
|
||||
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Start finetuning
|
||||
coordinator.print_on_master(f"Start training")
|
||||
model.train()
|
||||
train_dataloader_iter = iter(dataloader)
|
||||
total_len = len(train_dataloader_iter) - 1
|
||||
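# draw one batch up front; it is only used by the performance evaluator for token accounting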
example_data = next(train_dataloader_iter)
|
||||
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
|
||||
for step in pbar:
|
||||
performance_evaluator.on_step_start(step)
|
||||
if use_pipeline:
|
||||
# Forward pass
|
||||
outputs = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = outputs["loss"]
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
else:
|
||||
# Forward pass
|
||||
data = next(train_dataloader_iter)
|
||||
data = move_to_cuda(data, torch.cuda.current_device())
|
||||
outputs = model(**data)
|
||||
loss = outputs["loss"]
|
||||
# Backward
|
||||
booster.backward(loss, optimizer)
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(example_data["input_ids"])
|
||||
if (step == args.warmup // 2) and args.load_balance:
|
||||
coordinator.print_on_master(f"Apply load balance")
|
||||
apply_load_balance(model, optimizer)
|
||||
performance_evaluator.on_fit_end()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,78 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
NUM_GPU=8
|
||||
MODEL="8b"
|
||||
SEQ_LENGTH=2048
|
||||
WARMUP=20
|
||||
ACTIVE=4
|
||||
|
||||
# HACK: make model importable
|
||||
example_dir=$(dirname $(realpath $(dirname $0)))
|
||||
if [ -z ${PYTHONPATH+x} ]; then
|
||||
export PYTHONPATH=$example_dir
|
||||
else
|
||||
export PYTHONPATH=$example_dir:$PYTHONPATH
|
||||
fi
|
||||
|
||||
|
||||
# ep
|
||||
echo -e "\n\n Naive EP \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 8 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep \
|
||||
--zero_stage 2
|
||||
|
||||
|
||||
# ep_zero
|
||||
echo -e "\n\n EP-ZERO \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 16 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance
|
||||
|
||||
echo -e "\n\n EP-ZERO + Overlap \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 16 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance \
|
||||
--overlap_alltoall
|
||||
|
||||
|
||||
# hybrid
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 128 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--use_kernel \
|
||||
--plugin hybrid \
|
||||
--pp_size 2 \
|
||||
--dp_size 1 \
|
||||
--ep_size 4 \
|
||||
--zero_stage 1 \
|
||||
--microbatch_size 32
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
NUM_GPU=8
|
||||
MODEL="8b"
|
||||
SEQ_LENGTH=2048
|
||||
WARMUP=20
|
||||
ACTIVE=4
|
||||
|
||||
# HACK: make model importable
|
||||
example_dir=$(dirname $(realpath $(dirname $0)))
|
||||
if [ -z ${PYTHONPATH+x} ]; then
|
||||
export PYTHONPATH=$example_dir
|
||||
else
|
||||
export PYTHONPATH=$example_dir:$PYTHONPATH
|
||||
fi
|
||||
|
||||
|
||||
# ep
|
||||
echo -e "\n\n Naive EP \n\n"
|
||||
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 12 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep \
|
||||
--zero_stage 2
|
||||
|
||||
|
||||
# ep_zero
|
||||
echo -e "\n\n EP-ZERO \n\n"
|
||||
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 20 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance \
|
||||
--overlap_alltoall
|
|
@ -0,0 +1,139 @@
|
|||
import argparse
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import tqdm
|
||||
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from utils import PerformanceEvaluator, get_model_numel
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def fsdp_main(rank, world_size, args):
|
||||
|
||||
|
||||
# initialize the process group
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
MOE_MANAGER.setup(seed=42, parallel=None)
|
||||
|
||||
dp_size = dist.get_world_size()
|
||||
dataset = RandomDataset(
|
||||
max_length=args.seq_length,
|
||||
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
|
||||
)
|
||||
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
|
||||
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
|
||||
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_load_balance=False,
|
||||
enable_kernel=False,
|
||||
enable_comm_overlap=False,
|
||||
)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
model = OpenMoeForCausalLM(config)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={
|
||||
OpenMoeDecoderLayer,
|
||||
},
|
||||
)
|
||||
model = FSDP(
|
||||
model,
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
),
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
device_id=torch.cuda.current_device(),
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
|
||||
model.train()
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
performance_evaluator = PerformanceEvaluator(
|
||||
model_numel,
|
||||
enable_grad_checkpoint=True,
|
||||
ignore_steps=args.warmup,
|
||||
dp_world_size=dist.get_world_size(),
|
||||
)
|
||||
|
||||
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
performance_evaluator.on_step_start(step)
|
||||
input_ids, attention_mask, labels = (
|
||||
data["input_ids"].cuda(),
|
||||
data["attention_mask"].cuda(),
|
||||
data["labels"].cuda(),
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
attention_mask=attention_mask,
|
||||
chunk_head=False,
|
||||
)
|
||||
loss = output["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
performance_evaluator.on_step_end(input_ids)
|
||||
|
||||
performance_evaluator.on_fit_end()
|
||||
if dist.get_rank() == 0:
|
||||
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b"],
|
||||
help="base or 8b",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--seq_length", type=int, default=2048)
|
||||
parser.add_argument("--warmup", type=int, default=20)
|
||||
parser.add_argument("--active", type=int, default=20)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
fsdp_main(local_rank, world_size, args)
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
MODEL="8b"
|
||||
BATCH_SIZE=1
|
||||
SEQ_LENGTH=2048
|
||||
WARMUP=8
|
||||
ACTIVE=4
|
||||
|
||||
# HACK: make model importable
|
||||
example_dir=$(dirname $(realpath $(dirname $0)))
|
||||
if [ -z ${PYTHONPATH+x} ]; then
|
||||
export PYTHONPATH=$example_dir
|
||||
else
|
||||
export PYTHONPATH=$example_dir:$PYTHONPATH
|
||||
fi
|
||||
|
||||
# single node
|
||||
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size $BATCH_SIZE \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE
|
||||
|
||||
# multi node
|
||||
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
|
||||
$example_dir/benchmark/benchmark_fsdp.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size $BATCH_SIZE \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE
|
|
@ -0,0 +1,2 @@
|
|||
host1
|
||||
host2
|
|
@ -0,0 +1,126 @@
|
|||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
outputs = "Model param count: "
|
||||
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
if model_param >= B:
|
||||
outputs += f"{model_param / B:.2f} B\n"
|
||||
elif model_param >= M:
|
||||
outputs += f"{model_param / M:.2f} M\n"
|
||||
elif model_param >= K:
|
||||
outputs += f"{model_param / K:.2f} K\n"
|
||||
else:
|
||||
outputs += f"{model_param}\n"
|
||||
logger.info(outputs, ranks=[0])
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> None:
|
||||
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return model_param
|
||||
|
||||
|
||||
def divide(x: float, y: float) -> float:
|
||||
if y == 0:
|
||||
return float("inf")
|
||||
elif y == float("inf"):
|
||||
return float("nan")
|
||||
return x / y
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
assert self.start_time is not None
|
||||
self.duration += time() - self.start_time
|
||||
self.start_time = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.0
|
||||
|
||||
|
||||
class PerformanceEvaluator:
|
||||
"""
|
||||
Callback for evaluating the performance of the model.
|
||||
Args:
|
||||
model_numel: The number of parameters of the model.
|
||||
dp_world_size: The data parallel world size.
|
||||
|
||||
|
||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
||||
ignore_steps: The number of warmup steps to ignore when measuring performance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_numel: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_steps: int = 0,
|
||||
dp_world_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model_numel = model_numel
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_steps = ignore_steps
|
||||
self.dp_world_size = dp_world_size
|
||||
self.world_size = dist.get_world_size()
|
||||
self.disable: bool = False
|
||||
self.timer = Timer()
|
||||
self.num_samples: int = 0
|
||||
self.flop: int = 0
|
||||
|
||||
def on_step_start(self, step: int) -> None:
|
||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.start()
|
||||
|
||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.end()
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
self.num_samples += batch_size
|
||||
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
|
||||
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
|
||||
mp_world_size = self.world_size // self.dp_world_size
|
||||
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
|
||||
f"avg_throughput: {avg_throughput}")
|
||||
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
|
|
@ -0,0 +1,57 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
|
||||
from transformers import T5Tokenizer
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def inference(args):
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
if args.model == "test":
|
||||
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
|
||||
set_openmoe_args(config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_kernel=True)
|
||||
model = OpenMoeForCausalLM(config)
|
||||
else:
|
||||
config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
|
||||
set_openmoe_args(config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_kernel=False)
|
||||
model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
|
||||
model = model.eval().bfloat16()
|
||||
model = model.to(torch.cuda.current_device())
|
||||
|
||||
input_str = """```
|
||||
y = list(map(int, ['1', 'hello', '2']))
|
||||
```
|
||||
What error does this program produce?
|
||||
ValueError: invalid literal for int() with base 10: 'hello'
|
||||
|
||||
```
|
||||
sum = 0
|
||||
for i in range(100):
|
||||
sum += i
|
||||
```
|
||||
What is the value of sum immediately after the 10th time line 3 is executed?"""
|
||||
|
||||
# print("model config: ", model.config)
|
||||
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
|
||||
input_ids = input_ids.input_ids.to(torch.cuda.current_device())
|
||||
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
|
||||
out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
|
||||
print(f"output: \n{out}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
inference(args)
|
|
@ -0,0 +1 @@
|
|||
python infer.py --model "base"
|
|
@ -0,0 +1,224 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 Google LLC and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert T5X checkpoint to PyTorch
|
||||
|
||||
Steps:
|
||||
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
|
||||
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
|
||||
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
|
||||
- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
|
||||
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
|
||||
- Convert:
|
||||
```
|
||||
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
|
||||
--pytorch_dump_path=$HOME/t5_1_1_small_pt
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
|
||||
import torch
|
||||
from flax import traverse_util
|
||||
from modeling_openmoe import OpenMoeForCausalLM
|
||||
from t5x import checkpoints
|
||||
from transformers import LlamaConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
|
||||
"""Returns the KOQV parameters of (self-)attention. Does not transpose."""
|
||||
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
|
||||
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
|
||||
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
|
||||
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
|
||||
return k, o, q, v
|
||||
|
||||
|
||||
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
|
||||
"""Returns the MLP parameters of a layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
|
||||
"""Returns the MLP parameters of a layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
|
||||
"""Returns the MLP parameters of a layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
|
||||
"""Returns the MLP parameters of a layer. Does not transpose."""
|
||||
return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
|
||||
|
||||
|
||||
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
|
||||
"""Returns the layer norm param of a layer."""
|
||||
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
|
||||
|
||||
|
||||
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
|
||||
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
|
||||
old = traverse_util.flatten_dict(variables["target"])
|
||||
old = {"/".join(k): v for k, v in old.items()}
|
||||
|
||||
# v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
|
||||
split_mlp_wi = True
|
||||
print("Split MLP:", split_mlp_wi)
|
||||
|
||||
new = collections.OrderedDict()
|
||||
print(old.keys())
|
||||
for key, value in old.items():
|
||||
print(f"{key}: {value.shape}")
|
||||
|
||||
# Shared embeddings.
|
||||
new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
|
||||
|
||||
# Decoder.
|
||||
for i in range(num_layers):
|
||||
# Block i, layer 0 (Self Attention).
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
|
||||
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
|
||||
new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
|
||||
new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
|
||||
new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
|
||||
new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
|
||||
new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
|
||||
|
||||
# Block i, layer 2 (MLP).
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
|
||||
new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
|
||||
|
||||
if (i + 1) % moe_interval == 0:
|
||||
# moe
|
||||
gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
|
||||
wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
|
||||
new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
|
||||
new[f"model.layers.{i}.mlp.experts.wo"] = wo
|
||||
# extra
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
|
||||
new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
|
||||
wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
|
||||
new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
|
||||
new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
|
||||
else:
|
||||
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
|
||||
new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
|
||||
new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
|
||||
|
||||
new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
|
||||
|
||||
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
|
||||
if "decoder/logits_dense/kernel" in old:
|
||||
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
|
||||
|
||||
return new
|
||||
|
||||
|
||||
def make_state_dict(converted_params):
|
||||
"""Prepares a state dict for the PyTorch model."""
|
||||
# Make a state dict with torch tensors.
|
||||
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
|
||||
"""Replaces the params in model witht the T5X converted params."""
|
||||
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
||||
converted = convert_t5x_to_pytorch(variables,
|
||||
num_layers=config.num_hidden_layers,
|
||||
moe_interval=config.moe_layer_interval)
|
||||
state_dict = make_state_dict(converted)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
|
||||
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
|
||||
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
|
||||
# Initialise PyTorch model
|
||||
config = LlamaConfig.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
|
||||
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
|
||||
model = OpenMoeForCausalLM(config)
|
||||
|
||||
# Load weights from the T5X checkpoint
|
||||
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Verify that we can load the checkpoint.
|
||||
model.from_pretrained(pytorch_dump_path)
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
|
||||
# Required parameters
|
||||
parser.add_argument("--t5x_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the T5X checkpoint.")
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
|
@ -0,0 +1 @@
|
|||
python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save
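The same conversion can also be driven from Python instead of the CLI. A minimal sketch, assuming the script above is importable as convert_openmoe_ckpt and that the three paths are placeholders to be replaced with real locations:

# Hypothetical programmatic use of the converter script above; all paths are placeholders.
from convert_openmoe_ckpt import convert_t5x_checkpoint_to_pytorch

convert_t5x_checkpoint_to_pytorch(
    t5x_checkpoint_path="/path/to/t5x",      # T5X checkpoint directory
    config_file="/path/to/config.json",      # LlamaConfig-style json (see the configs below)
    pytorch_dump_path="/path/to/save",       # output directory for the PyTorch checkpoint
)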
|
File diff suppressed because it is too large
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"architectures": [
|
||||
"OpenMoeForCausalLM"
|
||||
],
|
||||
"intermediate_size": 8192,
|
||||
"hidden_size": 2048,
|
||||
"num_hidden_layers": 24,
|
||||
"head_dim": 128,
|
||||
"num_attention_heads": 24,
|
||||
"dropout_rate": 0.0,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"vocab_size": 256384,
|
||||
"hidden_act": "swiglu",
|
||||
"num_experts": 32,
|
||||
"topk": 2,
|
||||
"capacity_factor_train": 1.25,
|
||||
"capacity_factor_eval": 2.0,
|
||||
"min_capacity": 4,
|
||||
"noisy_policy": null,
|
||||
"drop_tks": true,
|
||||
"expert_parallel": null,
|
||||
"gated": true,
|
||||
"moe_layer_interval": 6
|
||||
}
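As a quick sanity check of a config like the one above, the converter treats decoder layer i as an MoE layer when (i + 1) % moe_layer_interval == 0. A minimal sketch, assuming the json is saved locally under a placeholder name:

# Load the config above (saved as "openmoe_config.json", a placeholder filename)
# and list which decoder layers are MoE layers.
from transformers import LlamaConfig

config = LlamaConfig.from_json_file("openmoe_config.json")
moe_layers = [i for i in range(config.num_hidden_layers) if (i + 1) % config.moe_layer_interval == 0]
print(moe_layers)  # with 24 layers and moe_layer_interval 6: [5, 11, 17, 23]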
|
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"architectures": [
|
||||
"OpenMoeForCausalLM"
|
||||
],
|
||||
"intermediate_size": 2048,
|
||||
"hidden_size": 768,
|
||||
"num_hidden_layers": 12,
|
||||
"head_dim": 64,
|
||||
"num_attention_heads": 12,
|
||||
"dropout_rate": 0.0,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"vocab_size": 256384,
|
||||
"hidden_act": "swiglu",
|
||||
"num_experts": 16,
|
||||
"topk": 2,
|
||||
"capacity_factor_train": 1.25,
|
||||
"capacity_factor_eval": 2.0,
|
||||
"min_capacity": 4,
|
||||
"noisy_policy": null,
|
||||
"drop_tks": true,
|
||||
"expert_parallel": null,
|
||||
"gated": true,
|
||||
"moe_layer_interval": 4
|
||||
}
|
|
@ -0,0 +1,562 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
|
||||
|
||||
__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
|
||||
|
||||
|
||||
class OpenMoePolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pre_extra_mlp_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
ignore_if_not_exist=True,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OpenMoeDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OpenMoeModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in openmoe.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
with a customized forward method, and record this change in the policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "OpenMoeModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=model_cls)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "OpenMoeModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
@staticmethod
|
||||
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
|
||||
"""Divide layers into stages
|
||||
|
||||
"""
|
||||
if num_layers == 24 and num_stages == 4:
|
||||
return [7, 7, 7, 3]
|
||||
elif num_layers == 24 and num_stages == 2:
|
||||
return [15, 9]
|
||||
elif num_layers == 12 and num_stages == 4:
|
||||
return [5, 5, 5, 1]
|
||||
elif num_layers == 12 and num_stages == 2:
|
||||
return [8, 4]
|
||||
else:
|
||||
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
|
||||
return Policy.distribute_layers(num_layers, num_stages)
|
||||
|
||||
|
||||
class OpenMoeModelPolicy(OpenMoePolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=OpenMoeModel,
|
||||
new_forward=OpenMoePipelineForwards.openmoe_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
OpenMoeForCausalLM:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=OpenMoeForCausalLM,
|
||||
new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1):
|
||||
# tie weights
|
||||
return [{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}]
|
||||
return []
|
||||
|
||||
|
||||
class OpenMoePipelineForwards:
|
||||
"""
|
||||
This class serves as a micro-library of forward-function substitutions for OpenMoe (Llama-based) models
|
||||
under the pipeline parallel setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def openmoe_model_forward(
|
||||
self: OpenMoeModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
past_router_aux_loss: Optional[torch.FloatTensor] = None,
|
||||
past_router_z_loss: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# reset moe loss for different data
|
||||
MOE_MANAGER.reset_loss()
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# accumulate past losses into the current ones
|
||||
router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
|
||||
if past_router_aux_loss is not None and past_router_z_loss is not None:
|
||||
router_aux_loss = past_router_aux_loss + router_aux_loss
|
||||
router_z_loss = past_router_z_loss + router_z_loss
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple([
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
])
|
||||
# always return a dict for intermediate stages
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"router_aux_loss": router_aux_loss,
|
||||
"router_z_loss": router_z_loss,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def llama_for_causal_lm_forward(
|
||||
self: OpenMoeForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
chunk_head: Optional[bool] = True,
|
||||
past_router_aux_loss: Optional[torch.FloatTensor] = None,
|
||||
past_router_z_loss: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = OpenMoePipelineForwards.openmoe_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
past_router_aux_loss=past_router_aux_loss,
|
||||
past_router_z_loss=past_router_z_loss,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
(
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
all_hidden_states,
|
||||
attentions,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
) = outputs
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
|
||||
loss = None
|
||||
# if not training (no labels), just do the forward pass
|
||||
if labels is None:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
# the vocab size for openmoe is over 300k,
# which causes large activation memory in training (up to ~20 GB for one sequence),
# so we chunk the LM head and use activation checkpointing to reduce memory
|
||||
else:
|
||||
if chunk_head:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
logits = module(inputs[0])
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous().float()
|
||||
shift_labels = inputs[1][..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss = self._calculate_loss(shift_logits, shift_labels)
|
||||
return loss
|
||||
|
||||
return custom_forward
|
||||
|
||||
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
|
||||
loss = aux_loss + z_loss
|
||||
for batch_idx in range(hidden_states.shape[0]):
|
||||
loss = loss + torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.lm_head),
|
||||
hidden_states[batch_idx:batch_idx + 1, :],
|
||||
labels[batch_idx:batch_idx + 1, :],
|
||||
)
|
||||
logits = None
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
|
||||
loss = aux_loss + z_loss
|
||||
loss = loss + self._calculate_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs["hidden_states"]
|
||||
router_aux_loss = outputs["router_aux_loss"]
|
||||
router_z_loss = outputs["router_z_loss"]
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"past_router_aux_loss": router_aux_loss,
|
||||
"past_router_z_loss": router_z_loss,
|
||||
}
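The chunk_head branch above exists because the ~256k vocabulary makes the full (batch, seq, vocab) logits tensor very expensive to keep around. A minimal standalone sketch of the same idea, using plain cross-entropy in place of self._calculate_loss and ignoring the router losses:

# Sketch of the chunked LM-head loss: run the head one sample at a time under
# activation checkpointing so the full logits tensor is never materialized at once.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint


def chunked_lm_head_loss(lm_head: torch.nn.Linear, hidden_states: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    def loss_for_chunk(h, y):
        logits = lm_head(h).float()
        shift_logits = logits[..., :-1, :].contiguous()   # tokens < n predict n
        shift_labels = y[..., 1:].contiguous()
        return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    loss = hidden_states.new_zeros(())
    for b in range(hidden_states.shape[0]):   # one sequence at a time, as in the forward above
        loss = loss + torch.utils.checkpoint.checkpoint(loss_for_chunk, hidden_states[b:b + 1], labels[b:b + 1])
    return loss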
|
|
@ -0,0 +1,5 @@
|
|||
colossalai >= 0.3.3
|
||||
torch >= 1.8.1
|
||||
transformers >= 4.20.0
|
||||
sentencepiece
|
||||
datasets
|
|
@ -0,0 +1,37 @@
|
|||
pip install -r requirements.txt
|
||||
|
||||
# inference
|
||||
python infer.py --model "test"
|
||||
|
||||
# train
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep" \
|
||||
--batch_size 1
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep_zero" \
|
||||
--batch_size 1 \
|
||||
--zero_stage 1 \
|
||||
--extra_dp_size 2 \
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep_zero" \
|
||||
--batch_size 1 \
|
||||
--zero_stage 2 \
|
||||
--extra_dp_size 2 \
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--model_name "test" \
|
||||
--plugin "hybrid" \
|
||||
--num_epoch 1 \
|
||||
--pp_size 2 \
|
||||
--dp_size 1 \
|
||||
--ep_size 2 \
|
||||
--zero_stage 1 \
|
||||
--batch_size 1
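For the hybrid command above, the parallel sizes are expected to multiply out to the number of launched processes (train.py only derives a data-parallel size as world_size // pp_size and hands the fixed dp/ep/pp sizes to MOE_MANAGER). A small sanity check with the flag values used here, assuming that product rule holds:

# Hypothetical sanity check of the group sizes implied by the hybrid example above.
world_size = 4                          # --nproc_per_node 4
pp_size, ep_size, dp_size = 2, 2, 1     # --pp_size / --ep_size / --dp_size
assert pp_size * ep_size * dp_size == world_size
assert world_size // pp_size == ep_size * dp_size   # matches the dp_size computed in train.py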
|
|
@ -0,0 +1,377 @@
|
|||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
|
||||
from model.openmoe_policy import OpenMoeForCausalLMPolicy
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import T5Tokenizer
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
|
||||
ckpt_path = snapshot_download(repo_name)
|
||||
# single ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
|
||||
# sharded ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
else:
|
||||
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||
booster.load_model(model, ckpt_path)
|
||||
|
||||
|
||||
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
|
||||
texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
|
||||
data = tokenizer(
|
||||
texts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b", "test"],
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["ep", "ep_zero", "hybrid"],
|
||||
help="Parallel methos. ep_zero is recommended for general cases. ep can provides least memory consumption and hybrid suits large scale training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./outputs",
|
||||
help="The path of your saved model after finetuning.",
|
||||
)
|
||||
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size (per dp group) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_interval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=" The interval (steps) of saving checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "bf16", "fp16"],
|
||||
help="The mixed precision training.",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="yizhongw/self_instruct",
|
||||
help="dataset name from `datasets` repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
type=str,
|
||||
default="super_natural_instructions",
|
||||
help="task of corresponding dataset.",
|
||||
)
|
||||
|
||||
# optim
|
||||
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
|
||||
# zero stage for all plugins
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
|
||||
# ep_zero plugin
|
||||
parser.add_argument(
|
||||
"--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
|
||||
)
|
||||
# hybrid plugin
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
|
||||
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
|
||||
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
|
||||
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_layernorm_kernel",
|
||||
action="store_true",
|
||||
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||
)
|
||||
|
||||
# loss
|
||||
parser.add_argument(
|
||||
"--router_aux_loss_factor",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Moe router z loss. You can refer to STMoE for details.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--router_z_loss_factor",
|
||||
type=float,
|
||||
default=0.0001,
|
||||
help="Moe router aux loss. You can refer to STMoE for details.",
|
||||
)
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
|
||||
parser.add_argument(
|
||||
"--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
|
||||
)
|
||||
|
||||
# load balance
|
||||
parser.add_argument(
|
||||
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
|
||||
)
|
||||
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
|
||||
# communicate overlap
|
||||
parser.add_argument(
|
||||
"--comm_overlap",
|
||||
action="store_true",
|
||||
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
test_mode = args.model_name == "test"
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": OpenMoeForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": args.precision,
|
||||
"zero_stage": args.zero_stage,
|
||||
}
|
||||
mgr_dict = {
|
||||
"seed": 42,
|
||||
}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "ep_zero":
|
||||
dp_size = dist.get_world_size()
|
||||
use_ep_inside = False
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
extra_dp_size=args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size // args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "hybrid":
|
||||
dp_size = dist.get_world_size() // args.pp_size
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=args.pp_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=args.dp_size,
|
||||
fixed_ep_size=args.ep_size,
|
||||
fixed_pp_size=args.pp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build OpenMoe model
|
||||
if test_mode:
|
||||
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
|
||||
config.hidden_size = 128
|
||||
config.intermediate_size = 256
|
||||
config.vocab_size = 32000
|
||||
else:
|
||||
repo_name = "hpcaitech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
router_aux_loss_factor=args.router_aux_loss_factor,
|
||||
router_z_loss_factor=args.router_z_loss_factor,
|
||||
z_loss_factor=args.z_loss_factor,
|
||||
enable_load_balance=args.load_balance,
|
||||
enable_comm_overlap=args.comm_overlap,
|
||||
enable_kernel=args.use_kernel,
|
||||
)
|
||||
with skip_init():
|
||||
model = OpenMoeForCausalLM(config)
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
if test_mode:
|
||||
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
|
||||
collate_fn = None
|
||||
else:
|
||||
dataset = load_dataset(args.dataset, args.task_name)
|
||||
dataset = dataset["train"]
|
||||
collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
if not test_mode:
|
||||
load_ckpt(repo_name, model, booster)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
|
||||
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Start finetuning
|
||||
coordinator.print_on_master(f"Start finetuning")
|
||||
for epoch in range(args.num_epoch):
|
||||
model.train()
|
||||
train_dataloader_iter = iter(dataloader)
|
||||
total_len = len(train_dataloader_iter)
|
||||
with tqdm(
|
||||
range(total_len),
|
||||
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
|
||||
disable=not coordinator.is_master(),
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
if use_pipeline:
|
||||
# Forward pass
|
||||
outputs = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = outputs["loss"]
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
else:
|
||||
# Forward pass
|
||||
data = next(train_dataloader_iter)
|
||||
data = move_to_cuda(data, torch.cuda.current_device())
|
||||
outputs = model(**data)
|
||||
loss = outputs["loss"]
|
||||
# Backward
|
||||
booster.backward(loss, optimizer)
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Apply load balance
|
||||
if (
|
||||
args.load_balance
|
||||
and args.load_balance_interval > 0
|
||||
and (step + 1) % args.load_balance_interval == 0
|
||||
):
|
||||
coordinator.print_on_master(f"Apply load balance")
|
||||
apply_load_balance(model, optimizer)
|
||||
# save checkpoint
|
||||
if (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
booster.save_model(model, args.output_path, shard=True)
|
||||
|
||||
# save checkpoint at the end of each epochs
|
||||
booster.save_model(model, args.output_path, shard=True)
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
|
||||
# Finish training
|
||||
coordinator.print_on_master(f"Finish training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
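For reference, the tokenize_data collate above assumes each dataset row carries "prompt" and "completion" fields (the contract implied by the default --dataset). A minimal illustration of that contract, assuming tokenize_data is in scope and a CUDA device is available, since the collate moves tensors to the GPU:

# Hypothetical illustration of the batch format consumed by tokenize_data above.
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
batch = [{"prompt": "Translate to French: hello", "completion": " bonjour"}]   # made-up sample
data = tokenize_data(batch, tokenizer=tokenizer, max_length=32)
print(data["input_ids"].shape, data["labels"].shape)   # torch.Size([1, 32]) each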
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
NUM_GPU=8
|
||||
MODEL="8b"
|
||||
SEQ_LENGTH=2048
|
||||
BATCH_SIZE=1
|
||||
LR=0.00001
|
||||
|
||||
# ep zero
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name $MODEL \
|
||||
--plugin "ep_zero" \
|
||||
--batch_size $BATCH_SIZE \
|
||||
--lr $LR \
|
||||
--zero_stage 1 \
|
||||
--extra_dp_size 2
|
||||
|
||||
# ep
|
||||
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
|
||||
# --num_epoch 1 \
|
||||
# --model_name $MODEL \
|
||||
# --plugin "ep_zero" \
|
||||
# --batch_size $BATCH_SIZE \
|
||||
# --lr $LR \
|
||||
# --zero_stage 1
|
||||
|
||||
# hybrid
|
||||
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
|
||||
# --num_epoch 1 \
|
||||
# --model_name $MODEL \
|
||||
# --plugin "hybrid" \
|
||||
# --batch_size $BATCH_SIZE \
|
||||
# --lr $LR \
|
||||
# --zero_stage 1 \
|
||||
# --pp_size 2 \
|
||||
# --dp_size 1 \
|
||||
# --ep_size 2 \
|
|
@ -2,4 +2,4 @@
|
|||
markers =
|
||||
dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
|
||||
largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
|
||||
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
|
||||
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
||||
try:
|
||||
import triton
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
|
||||
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 16
|
||||
HIDDEN_SIZE = 32
|
||||
|
||||
|
||||
def SwiGLU(x):
|
||||
"""Gated linear unit activation function.
|
||||
Args:
|
||||
x : input array
|
||||
axis: the axis along which the split should be computed (default: -1)
|
||||
"""
|
||||
size = x.shape[-1]
|
||||
assert size % 2 == 0, "axis size must be divisible by 2"
|
||||
x1, x2 = torch.split(x, size // 2, -1)
|
||||
return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
|
||||
def test_llama_act_combine(dtype: torch.dtype):
|
||||
x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
|
||||
x_gate_torch = nn.Parameter(x_gate.detach().clone())
|
||||
x_gate_kernel = nn.Parameter(x_gate.detach().clone())
|
||||
x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
|
||||
x_up_torch = nn.Parameter(x_up.detach().clone())
|
||||
x_up_kernel = nn.Parameter(x_up.detach().clone())
|
||||
|
||||
torch_out = SwiGLU(x_gate_torch) * x_up_torch
|
||||
kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
|
||||
atol = 1e-5 if dtype == torch.float32 else 5e-2
|
||||
assert torch.allclose(torch_out, kernel_out, atol=atol)
|
||||
|
||||
torch_out.mean().backward()
|
||||
kernel_out.mean().backward()
|
||||
assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
|
||||
assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
|
||||
assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_llama_act_combine(torch.float16)
|
|
@ -0,0 +1,169 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
|
||||
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
|
||||
from colossalai.legacy.registry import GRADIENT_HANDLER
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_moe_epsize_param_dict
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
|
||||
|
||||
|
||||
class MoeModel(nn.Module):
|
||||
def __init__(self, enable_load_balance: bool = False):
|
||||
class TestSubModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moe = SparseMLP(
|
||||
num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
|
||||
)
|
||||
self.proj = nn.Linear(16, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.moe(x)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
super().__init__()
|
||||
self.test_embed = nn.Linear(4, 16)
|
||||
self.test_transform = TestSubModule()
|
||||
|
||||
def forward(self, x):
|
||||
MOE_MANAGER.reset_loss()
|
||||
|
||||
x = self.test_embed(x)
|
||||
x = self.test_transform(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@GRADIENT_HANDLER.register_module
|
||||
class MoeGradientHandler(BaseGradientHandler):
|
||||
"""A helper class to handle all-reduce operations in a data parallel group and
|
||||
moe model parallel. A all-reduce collective communication will be operated in
|
||||
:func:`handle_gradient` among a data parallel group.
|
||||
For better performance, it bucketizes the gradients of all parameters that are
|
||||
the same type to improve the efficiency of communication.
|
||||
|
||||
Args:
|
||||
model (Module): Model where the gradients accumulate.
|
||||
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, model, optimizer=None):
|
||||
super().__init__(model, optimizer)
|
||||
|
||||
def handle_gradient(self):
|
||||
"""A method running an all-reduce operation in a data parallel group.
|
||||
Then running an all-reduce operation for all parameters in experts
|
||||
across moe model parallel group
|
||||
"""
|
||||
if dist.get_world_size() > 1:
|
||||
epsize_param_dict = get_moe_epsize_param_dict(self._model)
|
||||
|
||||
# epsize is 1, indicating the params are replicated among processes in data parallelism
|
||||
# use the ParallelMode.DATA to get data parallel group
|
||||
# reduce gradients for all parameters in data parallelism
|
||||
if 1 in epsize_param_dict:
|
||||
bucket_allreduce(param_list=epsize_param_dict[1])
|
||||
|
||||
for ep_size in epsize_param_dict:
|
||||
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
|
||||
bucket_allreduce(
|
||||
param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
|
||||
)
|
||||
|
||||
|
||||
def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from ep model
|
||||
|
||||
Args:
|
||||
tp_model (MoeModule)
|
||||
ep_model (MoeModule)
|
||||
"""
|
||||
for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
|
||||
assert tp_name == ep_name
|
||||
if not is_moe_tensor(tp_param):
|
||||
if assert_grad_flag:
|
||||
assert torch.allclose(tp_param, ep_param)
|
||||
assert torch.allclose(tp_param.grad, ep_param.grad)
|
||||
else:
|
||||
tp_param.data.copy_(ep_param.data)
|
||||
continue
|
||||
|
||||
# gather param from ep model
|
||||
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
||||
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
|
||||
all_param = torch.cat(param_list, dim=0)
|
||||
if assert_grad_flag:
|
||||
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
||||
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
|
||||
all_grad = torch.cat(grad_list, dim=0)
|
||||
|
||||
# get tp param
|
||||
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2]
|
||||
tp_rank = get_ep_rank(tp_param)
|
||||
tp_dim = tp_dim[0] + 1
|
||||
tp_slice = [slice(None)] * tp_dim + [
|
||||
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
|
||||
]
|
||||
new_tp_param = all_param[tuple(tp_slice)]
|
||||
if assert_grad_flag:
|
||||
new_grad = all_grad[tuple(tp_slice)]
|
||||
if assert_grad_flag:
|
||||
assert torch.allclose(tp_param, new_tp_param)
|
||||
assert torch.allclose(tp_param.grad, new_grad)
|
||||
else:
|
||||
tp_param.data.copy_(new_tp_param.data)
|
||||
|
||||
|
||||
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from ep model
|
||||
|
||||
Args:
|
||||
local_model (MoeModule)
|
||||
ep_model (MoeModule)
|
||||
"""
|
||||
for (local_name, local_param), (ep_name, ep_param) in zip(
|
||||
local_model.named_parameters(), ep_model.named_parameters()
|
||||
):
|
||||
assert local_name == ep_name
|
||||
if "experts" not in local_name:
|
||||
if assert_grad_flag:
|
||||
assert torch.allclose(local_param, ep_param)
|
||||
assert torch.allclose(local_param.grad, ep_param.grad)
|
||||
else:
|
||||
local_param.data.copy_(ep_param.data)
|
||||
continue
|
||||
|
||||
# gather param from ep model
|
||||
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
||||
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
|
||||
all_param = torch.cat(param_list, dim=0)
|
||||
if assert_grad_flag:
|
||||
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
||||
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
|
||||
all_grad = torch.cat(grad_list, dim=0)
|
||||
|
||||
if assert_grad_flag:
|
||||
assert torch.allclose(local_param, all_param)
|
||||
assert torch.allclose(local_param.grad, all_grad)
|
||||
else:
|
||||
local_param.data.copy_(all_param.data)
|
||||
|
||||
|
||||
def assert_not_equal_in_group(tensor, process_group=None):
|
||||
# all gather tensors from different ranks
|
||||
world_size = dist.get_world_size(process_group)
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
|
||||
# check if they are equal one by one
|
||||
for i in range(world_size - 1):
|
||||
a = tensor_list[i]
|
||||
b = tensor_list[i + 1]
|
||||
assert not torch.allclose(
|
||||
a, b
|
||||
), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"
|
|
@@ -4,40 +4,58 @@ import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group

BATCH_SIZE = 4
DIM = 16
CONFIG = dict()


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    expert_module = nn.Linear
    expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )

    MOE_CONTEXT.setup(42)  # MOE initialization
    noisy_func = UniformNoiseGenerator()
    router = Top1Router(noisy_func=noisy_func)
    MOE_MANAGER.setup(42, parallel="EP")  # MOE initialization
    num_experts_list = [1, 2, 4]
    layer_list = []
    for num_experts in num_experts_list:
        exp = Experts(expert_module, num_experts, **expert_factor)
        moe_layer = MoeLayer(DIM, num_experts, router, exp)
        moe_layer = SparseMLP(
            hidden_size=DIM,
            intermediate_size=DIM * 4,
            num_experts=num_experts,
            router_top_k=1,
            router_noisy_policy="Jitter",
        )
        layer_list.append(moe_layer)

    model = nn.ModuleList(layer_list)
    model = model.to(get_current_device())
    dist_dict = MOE_MANAGER.parallel_info_dict
    assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
    assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
    assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
    assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
    assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
    assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)

    sync_moe_model_param(model)

    dist_dict = MOE_CONTEXT.parallel_info_dict
    assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
    assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
    # MoE model synchronization passed

    grad_handler = MoeGradientHandler(model, 0)
@@ -47,17 +65,18 @@ def run_test(rank, world_size, port):
    data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
    grad = torch.randn_like(data)

    MOE_CONTEXT.reset_loss()
    MOE_MANAGER.reset_loss()
    for layer in layer_list:
        data, _ = layer(data)
        data = layer(data)
    data.backward(grad)
    grad_handler.handle_gradient()

    assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group)

    assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group)
    assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group)
    # MoE grad handler test passed
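For orientation: the hunks above replace the old MoeLayer/Experts stack with the new SparseMLP module driven by MOE_MANAGER. Below is a minimal single-GPU sketch of that API, assuming colossalai.launch accepts world_size=1 and that SparseMLP takes exactly the keyword arguments used in the test (the port number is arbitrary).

# Hedged sketch: exercise SparseMLP outside the test harness on one GPU.
import torch
import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER

colossalai.launch(config=dict(), rank=0, world_size=1, host="localhost", port=29500, backend="nccl")
MOE_MANAGER.setup(42, parallel=None)  # no expert parallelism on a single device

mlp = SparseMLP(hidden_size=16, intermediate_size=32, num_experts=4, router_top_k=1).cuda()
tokens = torch.randn(8, 16, device="cuda")
out = mlp(tokens)  # tokens are routed to their top-1 expert and combined
assert out.shape == tokens.shape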
@@ -1,49 +1,47 @@
import pytest
import torch
import torch.nn as nn
import torch.distributed as dist

import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

BATCH_SIZE = 16
BATCH_SIZE = 4
NUM_EXPERTS = 4
CONFIG = dict()


def check_equal(tensor_a, tensor_b, atol=1e-06):
    assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router):
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
    # Here we do not need TF32, since it brings absolute error on results
    torch.backends.cuda.matmul.allow_tf32 = False

    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    local_rank = dist.get_rank()

    MOE_CONTEXT.setup(42)  # MOE environment initialization
    MOE_CONTEXT.reset_loss()
    torch.manual_seed(rs + local_rank)  # set each process has different random seed
    MOE_MANAGER.setup(42, parallel="EP")  # MOE environment initialization
    MOE_MANAGER.reset_loss()
    torch.manual_seed(rs + local_rank)  # set each process has different random seed

    # get randomized data
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)

    expert_module = nn.Linear
    expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
    expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
    layer = SparseMLP(hidden_size=hidden_size,
                      intermediate_size=hidden_size * 2,
                      num_experts=NUM_EXPERTS,
                      router_top_k=topk,
                      router_capacity_factor_train=1.0)
    layer = layer.to(get_current_device())
    if data_type == torch.float16:
        layer = layer.half()

    # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
    layer.use_kernel = False
    old_out, _ = layer(tokens)
    layer.enable_kernel = False
    old_out = layer(tokens)
    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)  # get gradient
@@ -56,8 +54,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
    tokens.grad.zero_()
    layer.gate_weight.grad.zero_()

    layer.use_kernel = True
    new_out, _ = layer(tokens)  # get outputs through colossal kernel
    layer.enable_kernel = True
    new_out = layer(tokens)  # get outputs through colossal kernel

    if data_type == torch.float32:
        check_equal(old_out, new_out)
@@ -86,11 +84,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
@pytest.mark.parametrize("topk", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router):
    spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router)
def test_moe_kernel(rs, hidden_size, data_type, topk):
    spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)


if __name__ == "__main__":
    test_moe_kernel(2, 256, torch.float16, Top2Router)
if __name__ == '__main__':
    test_moe_kernel(2, 256, torch.float16, 2)
@@ -1,50 +1,138 @@
|
|||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
sys.path.append(os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"examples/language/openmoe",
|
||||
))
|
||||
|
||||
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
|
||||
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
|
||||
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
|
||||
|
||||
|
||||
def exam_moe_checkpoint():
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = MoeModel(checkpoint=True)
|
||||
save_moe_model(model, "temp_path.pth")
|
||||
def get_config():
|
||||
config = LlamaConfig(
|
||||
vocab_size=300,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=2,
|
||||
head_dim=4,
|
||||
dropout_rate=0.0,
|
||||
hidden_act="swiglu",
|
||||
)
|
||||
set_openmoe_args(config, num_experts=16, moe_layer_interval=1)
|
||||
return config
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
other_model = MoeModel(checkpoint=True)
|
||||
load_moe_model(other_model, "temp_path.pth")
|
||||
|
||||
state_0 = model.state_dict()
|
||||
state_1 = other_model.state_dict()
|
||||
for k, v in state_0.items():
|
||||
u = state_1.get(k)
|
||||
def get_model(parallel):
|
||||
config = get_config()
|
||||
model = OpenMoeForCausalLM(config)
|
||||
|
||||
if parallel == None:
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
)
|
||||
elif parallel == "zero_ep":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
zero_stage=2,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
)
|
||||
elif parallel == "hybrid":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=2,
|
||||
zero_stage=1,
|
||||
microbatch_size=1,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, _, _, _, _ = booster.boost(model=model)
|
||||
return model, booster
|
||||
|
||||
|
||||
def _test_moe_checkpoint(parallel, shard):
|
||||
if parallel == None:
|
||||
MOE_MANAGER.setup(
|
||||
seed=42,
|
||||
parallel=None,
|
||||
)
|
||||
elif parallel == "zero2_ep":
|
||||
MOE_MANAGER.setup(
|
||||
seed=42,
|
||||
parallel="EP",
|
||||
)
|
||||
elif parallel == "hybrid":
|
||||
MOE_MANAGER.setup(
|
||||
seed=42,
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=1,
|
||||
fixed_ep_size=2,
|
||||
fixed_pp_size=2,
|
||||
)
|
||||
model1, booster1 = get_model(parallel)
|
||||
model2, booster2 = get_model(parallel)
|
||||
|
||||
if shard:
|
||||
booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1)
|
||||
booster2.load_model(model2, "./tmp_ckpt")
|
||||
else:
|
||||
booster1.save_model(model1, "tmp_ckpt.pth")
|
||||
booster2.load_model(model2, "tmp_ckpt.pth")
|
||||
|
||||
state1 = model1.state_dict()
|
||||
state2 = model2.state_dict()
|
||||
for k, v in state1.items():
|
||||
u = state2.get(k)
|
||||
assert torch.equal(u.data, v.data)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
os.remove("temp_path.pth")
|
||||
if shard:
|
||||
shutil.rmtree("./tmp_ckpt")
|
||||
else:
|
||||
os.remove("tmp_ckpt.pth")
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
exam_moe_checkpoint()
|
||||
def _run_dist(rank, world_size, port, parallel, shard):
|
||||
colossalai.launch(
|
||||
config=dict(),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host="localhost",
|
||||
port=port,
|
||||
backend="nccl",
|
||||
)
|
||||
_test_moe_checkpoint(parallel, shard)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"])
|
||||
@pytest.mark.parametrize("shard", [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_checkpoint(world_size):
|
||||
spawn(_run_dist)
|
||||
def test_moe_checkpoint(world_size, parallel, shard):
|
||||
spawn(_run_dist, world_size, parallel=parallel, shard=shard)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_checkpoint(world_size=4)
|
||||
test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True)
|
||||
|
|
|
@@ -1,55 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
|
||||
@parameterize("init_device_type", ["cpu", "cuda"])
|
||||
def exam_moe_colo_init(init_device_type):
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if init_device_type == "cuda":
|
||||
init_device = get_current_device()
|
||||
elif init_device_type == "cpu":
|
||||
init_device = torch.device("cpu")
|
||||
else:
|
||||
raise NotImplementedError("Unknown device found.")
|
||||
|
||||
with ColoInitContext(device=init_device):
|
||||
model = MoeModel(checkpoint=True)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)
|
||||
|
||||
if hasattr(param, "moe_info"):
|
||||
param.set_process_group(param.moe_info.pg)
|
||||
|
||||
if hasattr(param, "moe_info"):
|
||||
assert param.process_group.dp_world_size() == param.moe_info.dp_size
|
||||
else:
|
||||
assert param.process_group.dp_world_size() == world_size
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
exam_moe_colo_init()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_colo_init(world_size):
|
||||
spawn(_run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_colo_init(world_size=4)
|
|
@@ -0,0 +1,81 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import sync_moe_model_param
|
||||
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep
|
||||
|
||||
|
||||
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
|
||||
assert batch_size % world_size == 0
|
||||
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed, parallel=None)
|
||||
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed, parallel="EP")
|
||||
ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed, parallel="TP")
|
||||
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
|
||||
ep_model = ep_model.to(get_current_device())
|
||||
tp_model = tp_model.to(get_current_device())
|
||||
local_model = local_model.to(get_current_device())
|
||||
|
||||
# sync ep param
|
||||
sync_moe_model_param(ep_model)
|
||||
dist_dict = MOE_MANAGER.parallel_info_dict
|
||||
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
|
||||
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
|
||||
grad_handler = MoeGradientHandler(ep_model)
|
||||
# sync tp param
|
||||
sync_tp_from_ep(tp_model, ep_model)
|
||||
# sync local param
|
||||
sync_local_from_ep(local_model, ep_model)
|
||||
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.manual_seed(seed)
|
||||
tp_data = torch.randn(batch_size, dim, device=get_current_device())
|
||||
micro_batch_size = batch_size // world_size
|
||||
ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
|
||||
|
||||
out_local = local_model(tp_data)
|
||||
MOE_MANAGER.reset_loss()
|
||||
out_tp = tp_model(tp_data)
|
||||
MOE_MANAGER.reset_loss()
|
||||
out_ep = ep_model(ep_data)
|
||||
MOE_MANAGER.reset_loss()
|
||||
assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
|
||||
assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
|
||||
|
||||
out_local.mean().backward()
|
||||
out_tp.mean().backward()
|
||||
out_ep.mean().backward()
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
|
||||
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
|
||||
|
||||
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
|
||||
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_experts", [4, 8])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("dim", [32])
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
|
||||
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)
|
|
@@ -3,66 +3,80 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.nn.layer.moe import Experts
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import sync_moe_model_param
|
||||
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
|
||||
D_MODEL = 4
|
||||
D_FF = 8
|
||||
CONFIG = dict()
|
||||
HIDDEN_SIZE = 4
|
||||
INTERMEDIATE_SIZE = 8
|
||||
|
||||
|
||||
def run_test(rank, world_size, port):
|
||||
world_size = 4
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
expert_module = nn.Linear
|
||||
expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
|
||||
def run_moe_init(expert_parallel):
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed=42, parallel=expert_parallel)
|
||||
expert_args = dict(
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
intermediate_size=INTERMEDIATE_SIZE,
|
||||
expert_parallel=expert_parallel,
|
||||
)
|
||||
exp0 = MLPExperts(1, **expert_args)
|
||||
exp1 = MLPExperts(2, **expert_args)
|
||||
exp2 = MLPExperts(4, **expert_args)
|
||||
|
||||
MOE_CONTEXT.setup(42) # MOE environment initialization
|
||||
exp0 = Experts(expert_module, 1, **expert_factor)
|
||||
exp1 = Experts(expert_module, 2, **expert_factor)
|
||||
exp2 = Experts(expert_module, 4, **expert_factor)
|
||||
exp3 = Experts(expert_module, 8, **expert_factor)
|
||||
if expert_parallel == "EP":
|
||||
assert exp0.num_local_experts == 1
|
||||
assert exp1.num_local_experts == 1
|
||||
assert exp2.num_local_experts == 2
|
||||
else:
|
||||
assert exp0.num_local_experts == 1
|
||||
assert exp1.num_local_experts == 2
|
||||
assert exp2.num_local_experts == 4
|
||||
|
||||
assert exp0.num_local_experts == 1
|
||||
assert exp1.num_local_experts == 1
|
||||
assert exp2.num_local_experts == 1
|
||||
assert exp3.num_local_experts == 2
|
||||
# experts deployment passed
|
||||
|
||||
parallel_info_dict = MOE_CONTEXT.parallel_info_dict
|
||||
parallel_info_dict = MOE_MANAGER.parallel_info_dict
|
||||
rank = dist.get_rank()
|
||||
|
||||
assert len(parallel_info_dict) == 3
|
||||
assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
|
||||
# group creation assert
|
||||
assert len(parallel_info_dict) == 2
|
||||
assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
|
||||
assert dist.get_rank(parallel_info_dict[1].ep_group) == 0
|
||||
|
||||
assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
|
||||
assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
|
||||
assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
|
||||
# group creation passed
|
||||
|
||||
model = nn.ModuleList([exp0, exp1, exp2, exp3])
|
||||
model = nn.ModuleList([exp0, exp1, exp2])
|
||||
model = model.to(get_current_device())
|
||||
sync_moe_model_param(model)
|
||||
|
||||
assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group)
|
||||
assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group)
|
||||
# MOE experts layout success when ep_size = 1
|
||||
assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group)
|
||||
assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group)
|
||||
|
||||
assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group)
|
||||
assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group)
|
||||
# MOE experts layout success when ep_size = 2
|
||||
assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group)
|
||||
assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group)
|
||||
|
||||
|
||||
def _run_test(rank, world_size, port, expert_parallel):
|
||||
colossalai.launch(
|
||||
config=dict(),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host="localhost",
|
||||
port=port,
|
||||
backend="nccl",
|
||||
)
|
||||
run_moe_init(expert_parallel)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_initialization():
|
||||
spawn(run_test, 4)
|
||||
def test_moe_initialization(expert_parallel):
|
||||
spawn(_run_test, 2, expert_parallel=expert_parallel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_initialization()
|
||||
test_moe_initialization("EP")
|
||||
test_moe_initialization("TP")
|
||||
|
|
|
@@ -0,0 +1,97 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeModel
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if isinstance(model, LowLevelZeroModel):
|
||||
optimizer.backward(loss / 2)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
|
||||
def run_zero_optim_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
data = torch.randn(16, 4).cuda()
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed=42, parallel=None)
|
||||
torch_model = MoeModel()
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters())
|
||||
torch_model = torch_model.cuda()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
|
||||
zero_model = MoeModel()
|
||||
extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
|
||||
ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
|
||||
ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
if is_moe_tensor(zero_param):
|
||||
num_expert = torch_param.data.shape[0]
|
||||
zero_param.data.copy_(
|
||||
torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
|
||||
.detach()
|
||||
.clone()
|
||||
)
|
||||
else:
|
||||
zero_param.data.copy_(torch_param.data.detach().clone())
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
|
||||
plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
|
||||
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
torch_optimizer.step()
|
||||
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
zero_optimizer.step()
|
||||
|
||||
for (torch_name, torch_param), (zero_name, zero_param) in zip(
|
||||
torch_model.named_parameters(), zero_model.named_parameters()
|
||||
):
|
||||
if is_moe_tensor(zero_param):
|
||||
num_expert = torch_param.data.shape[0]
|
||||
torch_param.data = torch_param.data[
|
||||
ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)
|
||||
]
|
||||
assert torch.allclose(
|
||||
torch_param.data, zero_param.data, atol=1e-4
|
||||
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_zero_optim_test(rank, world_size, stage=1)
|
||||
run_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_optim(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_optim(world_size=4)
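The key wiring in the new test above is handing the MoE-specific data-parallel group to the ZeRO plugin. A condensed sketch of just that setup follows; it assumes colossalai.launch has already been called on every rank, and the helper name is illustrative glue, not part of the library.

# Hedged sketch of the plugin wiring used above.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER


def build_moe_zero_booster(stage: int = 1) -> Booster:
    # cap expert parallelism at 2 ranks; the remaining ranks form an extra
    # data-parallel group that the ZeRO optimizer must also reduce over
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
    extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group

    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
    # pass the MoE data-parallel group down to the low-level ZeRO optimizer
    plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
    return Booster(plugin=plugin)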
@@ -0,0 +1,190 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
||||
|
||||
|
||||
def split_ddp_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if isinstance(model, LowLevelZeroModel):
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
|
||||
def run_zero_optim_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(
|
||||
seed=42,
|
||||
parallel="EP",
|
||||
)
|
||||
zero_model = MoeModel(enable_load_balance=True)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True)
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed=42, parallel="EP")
|
||||
torch_model = MoeModel()
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
torch_param.data.copy_(zero_param.data)
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters())
|
||||
torch_model = torch_model.cuda().bfloat16()
|
||||
grad_handler = MoeGradientHandler(torch_model)
|
||||
|
||||
# run to update expert load
|
||||
data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1)
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
|
||||
# run torch model twice
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
grad_handler.handle_gradient()
|
||||
torch_optimizer.step()
|
||||
torch_optimizer.zero_grad()
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
# get optim and load status in zero model
|
||||
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
zero_optimizer.step()
|
||||
zero_optimizer.zero_grad()
|
||||
with torch.no_grad():
|
||||
origin_out = zero_model(data)
|
||||
|
||||
# load balance
|
||||
apply_load_balance(zero_model, zero_optimizer)
|
||||
|
||||
# run again to test
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
torch.allclose(origin_out, zero_out)
|
||||
|
||||
# assert optim
|
||||
torch_optimizer.step()
|
||||
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
zero_optimizer.step()
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}"
|
||||
|
||||
|
||||
def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
data = torch.randn(16, 4).cuda()
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(seed=42, parallel=None)
|
||||
torch_model = MoeModel()
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters())
|
||||
torch_model = torch_model.cuda()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(
|
||||
seed=42,
|
||||
max_ep_size=2,
|
||||
use_ep_inside=False,
|
||||
parallel="EP",
|
||||
)
|
||||
zero_model = MoeModel(enable_load_balance=True)
|
||||
extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
|
||||
ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
|
||||
ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
if is_moe_tensor(zero_param):
|
||||
num_expert = torch_param.data.shape[0]
|
||||
zero_param.data.copy_(
|
||||
torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
|
||||
.detach()
|
||||
.clone()
|
||||
)
|
||||
else:
|
||||
zero_param.data.copy_(torch_param.data.detach().clone())
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
|
||||
plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
|
||||
|
||||
# run torch for twice
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
torch_optimizer.step()
|
||||
torch_optimizer.zero_grad()
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
torch_optimizer.step()
|
||||
|
||||
# run zero
|
||||
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
zero_optimizer.step()
|
||||
zero_optimizer.zero_grad()
|
||||
with torch.no_grad():
|
||||
origin_out = zero_model(data)
|
||||
|
||||
# load balance
|
||||
apply_load_balance(zero_model, zero_optimizer)
|
||||
|
||||
# assert out
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
torch.allclose(origin_out, zero_out)
|
||||
|
||||
# assert optim
|
||||
zero_optimizer.step()
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
# TODO: high atol, check if bug exists
|
||||
assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}"
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(
|
||||
config=dict(),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host="localhost",
|
||||
port=port,
|
||||
backend="nccl",
|
||||
)
|
||||
run_zero_optim_test(rank, world_size, stage=1)
|
||||
run_zero_optim_test(rank, world_size, stage=2)
|
||||
run_hybrid_zero_optim_test(rank, world_size, stage=1)
|
||||
run_hybrid_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_load_balance(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_load_balance(world_size=4)
|
|
@@ -0,0 +1,41 @@
import pytest
import torch

from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


@pytest.mark.parametrize(["router", "num_groups"], [
    (Top1Router(), 1),
    (Top2Router(), 1),
    (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
    (4, 5, 8),
    (3, 4, 4),
])
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
    x = torch.randn((batch_size * seq_len, num_experts)).cuda()
    if num_groups > 1:
        x = x.expand(num_groups, -1, -1)

    router.train()
    if isinstance(router, TopKRouter):
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        combine_array, dispatch_mask = router(x)
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

    router.eval()
    if isinstance(router, TopKRouter):
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        combine_array, dispatch_mask = router(x)
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)


if __name__ == "__main__":
    test_router_forward(Top1Router(), 4, 4, 4, 1)
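The router test above only checks output shapes and the top-k dispatch bound. As a quick illustration, the same properties can be probed standalone, mirroring the test's assumptions about the colossalai.moe.routers API.

# Hedged sketch mirroring the assertions in test_router_forward above.
import torch
from colossalai.moe.routers import Top2Router

router = Top2Router()
logits = torch.randn(12, 4).cuda()  # (tokens, num_experts) routing scores
combine_array, dispatch_mask = router(logits)
print(combine_array.shape, dispatch_mask.shape)
# every token is dispatched to at most k experts (k == 2 for Top2Router)
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)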
@@ -0,0 +1,105 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
||||
|
||||
|
||||
def split_ddp_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if isinstance(model, LowLevelZeroModel):
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
|
||||
def run_zero_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
zero_model = MoeModel()
|
||||
optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
|
||||
|
||||
torch_model = MoeModel()
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
torch_param.data.copy_(zero_param.data)
|
||||
torch_model = torch_model.cuda()
|
||||
grad_handler = MoeGradientHandler(torch_model)
|
||||
|
||||
# assert zero model
|
||||
for (torch_name, torch_param), (zero_name, zero_param) in zip(
|
||||
torch_model.named_parameters(), zero_model.module.named_parameters()
|
||||
):
|
||||
assert zero_name == torch_name
|
||||
assert torch.allclose(zero_param.data, torch_param.data)
|
||||
|
||||
data = torch.randn(16, 4).cuda()
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
|
||||
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
|
||||
assert torch.allclose(torch_out, zero_out)
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
for (zero_name, zero_param), (torch_name, torch_param) in zip(
|
||||
zero_model.module.named_parameters(), torch_model.named_parameters()
|
||||
):
|
||||
assert zero_name == torch_name
|
||||
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
|
||||
if hasattr(zero_param, "moe_info"):
|
||||
assert len(zero_grad_list) == 0
|
||||
assert torch.allclose(zero_param.grad, torch_param.grad)
|
||||
else:
|
||||
assert len(zero_grad_list) > 0
|
||||
torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
|
||||
if stage == 2:
|
||||
torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
|
||||
assert len(zero_grad_list) == len(torch_grad_list)
|
||||
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
|
||||
assert torch.allclose(zero_grad, torch_grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_MANAGER.setup(seed=42, parallel="EP")
|
||||
seed_all(42 + rank)
|
||||
run_zero_test(rank, world_size, stage=1)
|
||||
run_zero_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_model(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_model(world_size=2)
|
|
@@ -1,106 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CheckpointModule
|
||||
from colossalai.nn.layer import MoeModule
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
|
||||
class MoeModel(nn.Module):
|
||||
def __init__(self, checkpoint: bool = False):
|
||||
class TestSubModule(CheckpointModule):
|
||||
def __init__(self):
|
||||
super().__init__(checkpoint)
|
||||
expert_cls = nn.Linear
|
||||
expert_args_dict = dict(in_features=16, out_features=16)
|
||||
self.moe = MoeModule(
|
||||
dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict
|
||||
)
|
||||
self.proj = nn.Linear(16, 4)
|
||||
|
||||
def _forward(self, x):
|
||||
x, y = self.moe(x)
|
||||
x = self.proj(x)
|
||||
return x, y
|
||||
|
||||
super().__init__()
|
||||
self.test_embed = nn.Linear(4, 16)
|
||||
self.test_transform = TestSubModule()
|
||||
|
||||
def forward(self, x):
|
||||
MOE_CONTEXT.reset_loss()
|
||||
|
||||
x = self.test_embed(x)
|
||||
x, y = self.test_transform(x)
|
||||
|
||||
MOE_CONTEXT.add_loss(y)
|
||||
return x
|
||||
|
||||
|
||||
@parameterize("init_device_type", ["cpu", "cuda"])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_moe_zero_init(init_device_type, shard_strategy_class):
|
||||
get_dist_logger("test_moe_zero_init")
|
||||
|
||||
if init_device_type == "cuda":
|
||||
init_device = get_current_device()
|
||||
elif init_device_type == "cpu":
|
||||
init_device = torch.device("cpu")
|
||||
else:
|
||||
raise NotImplementedError("Unknown device found.")
|
||||
|
||||
model_numel_tensor = torch.zeros(1, dtype=torch.int)
|
||||
with ZeroInitContext(
|
||||
target_device=init_device,
|
||||
shard_strategy=shard_strategy_class(),
|
||||
shard_param=True,
|
||||
model_numel_tensor=model_numel_tensor,
|
||||
):
|
||||
model = MoeModel(checkpoint=True)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
assert hasattr(param, "colo_attr")
|
||||
|
||||
# the parameters in moe experts and its gate should not be sharded
|
||||
if ("experts" in name) or ("gate" in name) or ("residual_combine" in name):
|
||||
assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
|
||||
else:
|
||||
assert param.colo_attr.sharded_data_tensor.is_sharded
|
||||
|
||||
# the parameters in moe experts is not replicated
|
||||
if "experts" in name:
|
||||
assert not param.colo_attr.is_replicated
|
||||
else:
|
||||
assert param.colo_attr.is_replicated
|
||||
|
||||
if param.colo_attr.param_is_sharded:
|
||||
assert (
|
||||
param.colo_attr.data_payload.device.type == init_device.type
|
||||
), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}"
|
||||
else:
|
||||
assert param.colo_attr.data_payload.device.type == "cuda"
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
run_moe_zero_init()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_init(world_size):
|
||||
spawn(_run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_init(world_size=2)
|
|
@@ -1,70 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
|
||||
from colossalai.nn import MoeLoss
|
||||
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.zero.legacy.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd
|
||||
|
||||
|
||||
@parameterize("enable_autocast", [False])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(enable_autocast, shard_strategy_class):
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
|
||||
_, train_dataloader, _, optimizer_class, _ = get_components_func()
|
||||
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
|
||||
|
||||
with ZeroInitContext(
|
||||
target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True
|
||||
):
|
||||
zero_model = MoeModel(checkpoint=True)
|
||||
zero_model = ShardedModelV2(zero_model, shard_strategy)
|
||||
|
||||
# check whether parameters are identical in ddp
|
||||
for name, p in zero_model.named_parameters():
|
||||
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
|
||||
assert_equal_in_group(p.colo_attr.data_payload)
|
||||
|
||||
model = MoeModel(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda()
|
||||
grad_handler = MoeGradientHandler(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, enable_autocast)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
check_grads_padding(model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
run_model_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_model(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_model(world_size=2)
|
|
@@ -2,120 +2,91 @@ import pytest
|
|||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.legacy.amp import convert_to_apex_amp
|
||||
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
|
||||
from colossalai.nn import MoeLoss
|
||||
from colossalai.nn.optimizer import CPUAdam
|
||||
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
|
||||
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.low_level._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
||||
|
||||
|
||||
def _run_step(model, optimizer, data, label, criterion, grad_handler):
|
||||
def split_ddp_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
|
||||
loss = loss.float()
|
||||
if isinstance(model, ShardedModelV2):
|
||||
if isinstance(model, LowLevelZeroModel):
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
if grad_handler is not None:
|
||||
|
||||
def run_zero_optim_test(local_rank, world_size, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
zero_model = MoeModel()
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
|
||||
booster = Booster(plugin=plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
|
||||
|
||||
torch_model = MoeModel()
|
||||
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
|
||||
torch_param.data.copy_(zero_param.data)
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters())
|
||||
torch_model = torch_model.cuda()
|
||||
grad_handler = MoeGradientHandler(torch_model)
|
||||
|
||||
for _ in range(2):
|
||||
data = torch.randn(16, 4).cuda() / (local_rank + 1)
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
run_fwd_bwd(torch_model, data, label, criterion, None)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
grad_handler.handle_gradient()
|
||||
|
||||
optimizer.step()
|
||||
torch_optimizer.step()
|
||||
zero_optimizer.step()
|
||||
|
||||
for (torch_name, torch_param), (zero_name, zero_param) in zip(
|
||||
torch_model.named_parameters(), zero_model.named_parameters()
|
||||
):
|
||||
assert torch.allclose(
|
||||
torch_param.data, zero_param.data
|
||||
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
|
||||
|
||||
torch_optimizer.zero_grad()
|
||||
zero_optimizer.zero_grad()
|
||||
|
||||
|
||||
@parameterize("cpu_offload", [True])
|
||||
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
|
||||
@parameterize("reuse_fp16_shard", [True, False])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def _run_test_sharded_optim_v2(
|
||||
cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0
|
||||
):
|
||||
shard_strategy = shard_strategy_class()
|
||||
if use_cpuadam and cpu_offload is False:
|
||||
return
|
||||
MOE_CONTEXT.reset_loss()
|
||||
get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
|
||||
_, train_dataloader, _, optimizer_class, _ = get_components_func()
|
||||
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
|
||||
|
||||
with ZeroInitContext(
|
||||
target_device=torch.device("cpu") if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True,
|
||||
):
|
||||
zero_model = MoeModel(checkpoint=True)
|
||||
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
shard_strategy,
|
||||
tensor_placement_policy="cpu" if cpu_offload else "cuda",
|
||||
reuse_fp16_shard=reuse_fp16_shard,
|
||||
)
|
||||
|
||||
# check whether parameters are identical in ddp
|
||||
for name, p in zero_model.named_parameters():
|
||||
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
|
||||
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
|
||||
|
||||
model = MoeModel(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
model = model.cuda().float()
|
||||
|
||||
if use_cpuadam:
|
||||
optimizer_class = CPUAdam
|
||||
optim = optimizer_class(model.parameters(), lr=1e-3)
|
||||
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(
|
||||
zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio
|
||||
)
|
||||
|
||||
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False)
|
||||
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
|
||||
apex_grad_handler = MoeGradientHandler(model)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
break
|
||||
data, label = data.cuda(), label.cuda()
|
||||
_run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
|
||||
_run_step(zero_model, sharded_optim, data, label, criterion, None)
|
||||
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
|
||||
for param in model.parameters():
|
||||
assert not has_inf_or_nan(param)
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_MANAGER.setup(seed=42, parallel="EP")
|
||||
run_zero_optim_test(rank, world_size, stage=1)
|
||||
run_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
MOE_CONTEXT.setup(seed=42)
|
||||
_run_test_sharded_optim_v2()
|
||||
|
||||
|
||||
# use_cpuadam = True can be used with cpu_offload = False
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_optim(world_size):
|
||||
spawn(_run_dist, world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_optim(world_size=4)
|
||||
test_moe_zero_optim(world_size=2)
|
||||