mirror of https://github.com/hpcaitech/ColossalAI
[moe] refactor mesh assignment
parent 034020bd04
commit cb01c0d5ce
@@ -1,9 +1,8 @@
import warnings
from collections import defaultdict
from types import MethodType
-from typing import Callable, Optional, OrderedDict, Tuple
+from typing import Callable, List, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

@@ -13,6 +12,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.booster.plugin.hybrid_parallel_plugin import (
+    PRECISION_TORCH_TYPE,
+    SUPPORT_SP_MODE,
    HybridParallelAMPOptimizer,
    HybridParallelModule,
    HybridParallelNaiveOptimizer,

@@ -22,9 +23,16 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
    reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.tensor.moe_tensor.api import is_moe_tensor


@@ -57,7 +65,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
        forced_dtype: Optional[torch.dtype] = None,
        overlap_allgather: bool = False,
    ):
-        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
+        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
        if not force_overlap_comm and (overlap_communication or partition_grad):
            raise RuntimeError(
                WARN_STR

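The hang this warning guards against is a mismatched collective: a rank whose local expert received no tokens runs no backward for it and never joins the gradient reduction the other ranks have already launched. A minimal sketch of the failure mode, with hypothetical function and flag names (not ColossalAI API):

    import torch
    import torch.distributed as dist

    def reduce_expert_grad(expert_was_routed: bool, grad: torch.Tensor, group):
        # all_reduce blocks until every rank in `group` joins the call; a rank
        # whose expert saw no tokens has no backward for it, skips the call,
        # and leaves every other rank in the group waiting forever.
        if expert_was_routed:
            dist.all_reduce(grad, group=group)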
@@ -105,130 +113,219 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):

class MoeHybridParallelPlugin(HybridParallelPlugin):
    """
    TODO: add docstring
    Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
    Extra Args:
        ep_size (int): The size of expert parallelism. Expert parallelism will not be used when ep_size is set to 1.
        force_overlap_comm (bool): For LowLevelZeroOptimizer, the program might hang when only some experts are routed and overlap_communication is True during training. This flag is used to force overlap_communication=True.
    """

-    def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
-        if "overlap_communication" not in kwargs:
-            kwargs["overlap_communication"] = False  # defaults to True in the super class
-
-        super().__init__(*args, **kwargs)
-
-        if ep_size <= 1:
-            raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+    def __init__(
+        self,
+        tp_size: int,
+        pp_size: int,
+        ep_size: int,
+        sp_size: int = None,
+        precision: str = "fp16",
+        zero_stage: int = 0,
+        enable_all_optimization: bool = False,
+        enable_fused_normalization: bool = False,
+        enable_flash_attention: bool = False,
+        enable_jit_fused: bool = False,
+        enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
+        enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
+        num_microbatches: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+        initial_scale: float = 2**16,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        max_norm: float = 0,
+        broadcast_buffers: bool = True,
+        ddp_bucket_cap_mb: int = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
+        zero_bucket_size_in_m: int = 12,
+        cpu_offload: bool = False,
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = True,
+        custom_policy: Policy = None,
+        pp_style: str = "1f1b",
+        num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
+        enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 64,
+        dp_outside: bool = True,
+        overlap_p2p: bool = True,
+        overlap_allgather: bool = False,
+        force_overlap_comm: bool = False,
+    ) -> None:
        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
            )
            assert (
                self.sequence_parallelism_mode in SUPPORT_SP_MODE
            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
                assert (
                    tp_size > 1
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                if sp_size != 1:
                    warnings.warn(
                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
                    )
                self.sp_size = 1
                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            elif self.sequence_parallelism_mode in ["all_to_all"]:
                self.sp_size = 1 if sp_size is None else sp_size
                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
        else:
            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
            assert (
                sp_size == 1 or sp_size is None
            ), f"You should not set sp_size when sequence parallelism is not enabled."
            self.sp_size = 1

        assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
        self.moe_dp_size = self.dp_size // ep_size
        self.ep_size = ep_size
        self.moe_tp_size = moe_tp_size
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if dp_outside:
            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 0, 1
-            self.moe_pg_mesh = ProcessGroupMesh(
-                self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
-            )
        else:
            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 1, 2
-            self.moe_pg_mesh = ProcessGroupMesh(
-                self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
-            )

        self._init_moe_param_comm()
        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert (
                self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
                enable_interleave=pp_style == "interleaved",
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )

-        self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-            self.dp_size == 1
-            and self.pp_size == 1
-            and self.enable_sequence_parallelism
-            and self.sequence_parallelism_mode == "all_to_all"
-        )
            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
                self.schedule = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                    overlap_p2p=overlap_p2p,
                )
            elif pp_style == "1f1b":
                self.schedule = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                )
            else:
                raise NotImplementedError()

        self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
-        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
            self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
        else:
            self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            sequence_parallel_process_group=self.sp_group,
            ep_group=self.ep_group,
            moe_dp_group=self.moe_dp_group,
            pipeline_stage_manager=self.stage_manager,
            enable_tensor_parallelism=self.tp_size > 1,
            enable_all_optimization=self.enable_all_optimization,
            enable_fused_normalization=self.enable_fused_normalization,
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            sequence_parallelism_mode=sequence_parallelism_mode,
            enable_sequence_overlap=enable_sequence_overlap,
            parallel_output=parallel_output,
            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
            gradient_checkpoint_config=gradient_checkpoint_config,
        )
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )

-        if self.use_ddp:
-            warnings.warn(
-                f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
-            )
-            self.ddp_config["find_unused_parameters"] = True
        self.ddp_config = dict(
            broadcast_buffers=broadcast_buffers,
            bucket_cap_mb=ddp_bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            check_reduction=check_reduction,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )

-        if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
-            # TODO it might make sense to support non-moe with tp on but moe with tp off
-            raise ValueError(
-                f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
-            )

-        # set param group in shard config
-        self.shard_config.ep_group = self.ep_group
-        self.shard_config.moe_dp_group = self.moe_dp_group
-        self.shard_config.moe_tp_group = self.moe_tp_group
        self.zero_config = dict(
            reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            partition_grad=(self.zero_stage == 2),
            forced_dtype=PRECISION_TORCH_TYPE[precision],
            overlap_allgather=overlap_allgather,
        )

        self.max_norm = max_norm
        self.force_overlap_comm = force_overlap_comm

    def _init_moe_param_comm(self):
        world_size = dist.get_world_size()

        if self.enable_sequence_parallelism:
            if self.sequence_parallelism_mode == "all_to_all":
                # if sequence parallelism is enabled, ep_group reuses sp_group
                if self.ep_size != self.sp_size:
                    raise ValueError(
                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
                    )

                # since we are reusing sp_group, moe_dp_group will be derived as dp_group
                self.moe_dp_size = self.dp_size
                self.moe_dp_group = self.dp_group
                self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
                self.ep_group = self.sp_group
                self.moe_tp_group = self.tp_group
            else:
                raise NotImplementedError(
                    f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
                )

        else:
            self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)

            if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
                raise ValueError(
                    f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
                )

            self.moe_dp_group = None
            self.ep_group = None
            self.moe_tp_group = None
            self.dp_sp_group = self.dp_group

            # create submesh for ep, moe_dp, moe_tp
            ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
                [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
            )

            global_rank = self.pg_mesh.rank
            pp_rank = self.pg_mesh.coordinate(self.pp_axis)

            # create groups from submesh
            for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
                # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
                submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)

                # hardcode here since we only have 3 axes
                # moe_dp_group
                for ep_idx in range(self.ep_size):
                    for moe_tp_idx in range(self.moe_tp_size):
                        moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
                        group = dist.new_group(moe_dp_ranks)
                        if pp_rank == stage_idx and global_rank in moe_dp_ranks:
                            assert self.moe_dp_group is None
                            self.moe_dp_group = group
                # ep_group
                for moe_dp_idx in range(self.moe_dp_size):
                    for moe_tp_idx in range(self.moe_tp_size):
                        ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
                        group = dist.new_group(ep_ranks)
                        if pp_rank == stage_idx and global_rank in ep_ranks:
                            assert self.ep_group is None
                            self.ep_group = group
                # moe_tp_group
                for moe_dp_idx in range(self.moe_dp_size):
                    for ep_idx in range(self.ep_size):
                        moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
                        group = dist.new_group(moe_tp_ranks)
                        if pp_rank == stage_idx and global_rank in moe_tp_ranks:
                            assert self.moe_tp_group is None
                            self.moe_tp_group = group

        if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
            # NOTE: different tp settings between moe and non-moe params are complex to handle
            # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
            raise NotImplementedError(
                f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
            )

    def get_checkpoint_io(self) -> MoECheckpointIO:
        return MoECheckpointIO(
            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage

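To make the submesh arithmetic in _init_moe_param_comm concrete, here is a rank-only sketch (plain numpy, no process groups are created; the sizes are hypothetical) of how one pipeline stage's ranks are reshaped into a (moe_dp, ep, moe_tp) cube and sliced along each axis:

    import numpy as np

    moe_dp_size, ep_size, moe_tp_size = 2, 2, 1  # hypothetical 4-rank stage
    stage_ranks = [0, 1, 2, 3]
    # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp, as in the loops above
    submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

    moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    print(moe_dp_groups)  # [[0, 2], [1, 3]] -> ranks holding replicas of the same experts
    print(ep_groups)      # [[0, 1], [2, 3]] -> ranks exchanging tokens via all-to-all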
@@ -249,14 +346,37 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        optimizer = cast_to_distributed(optimizer)

        if not isinstance(model, ModelWrapper):
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if use_ddp:
+                warnings.warn(
+                    f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+                )
+                self.ddp_config["find_unused_parameters"] = True
+
+                if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                    raise ValueError(
+                        f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
+                    )
+
+            # sync gradients across DP * SP ranks
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            else:
+                dp_group = self.dp_group
+
            model = HybridParallelModule(
                module=model,
                precision=self.precision,
                shard_config=self.shard_config,
-                dp_group=self.dp_sp_group,
+                dp_group=dp_group,
                tp_group=self.tp_group,
                sp_group=self.sp_group,
-                use_ddp=self.use_ddp,
+                use_ddp=use_ddp,
                ddp_config=self.ddp_config,
                custom_policy=self.custom_policy,
            )

@@ -301,7 +421,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                use_pipeline=self.enable_pipeline_parallelism,
                force_overlap_comm=self.force_overlap_comm,
                param_info=param_info,
-                dp_process_group=self.dp_sp_group,
+                dp_process_group=dp_group,
                tp_process_group=self.tp_group,
                pp_process_group=self.pp_group,
                moe_dp_group=self.moe_dp_group,

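As a usage sketch of the refactored plugin under the new explicit signature (argument names come from the diff above; the model, optimizer, and sizes are hypothetical, and a distributed context launched via colossalai.launch is assumed):

    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    # 8 ranks: tp_size=2 * pp_size=2 leaves dp_size=2; ep_size=2 then gives
    # moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size) = 2 inside
    # _init_moe_param_comm. zero_stage=1 avoids the plain-DDP path, which
    # requires dp_group == moe_dp_group.
    plugin = MoeHybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        ep_size=2,
        zero_stage=1,
        precision="fp16",
        num_microbatches=4,
    )
    booster = Booster(plugin=plugin)
    model = torch.nn.Linear(8, 8)  # stand-in; a MoE transformer in practice
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)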
@@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module):
    def __init__(self):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
        assert tp_group is not None
        assert moe_dp_group is not None
        assert ep_group is not None
-        assert moe_tp_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)

@@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module):
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = moe_dp_group.size()

-        # setup global tp group
+        # setup tp group
        self.tp_group = tp_group

-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
            for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
+                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
+                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

@@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module):
        tp_group: ProcessGroup,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EPDeepseekMoE":

@@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module):
        if module.__class__.__name__ == "DeepseekMLP":
            return module
        module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

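A sketch of the updated call site (the module and plugin objects are hypothetical; the keyword list mirrors the new setup_process_groups signature and the test hunks further below):

    # moe_tp_group is gone from the wrapper's signature: expert weights are
    # now sharded with the plain tp_group passed in here.
    ep_module = EPDeepseekMoE.from_native_module(
        native_moe_block,  # hypothetical DeepseekMoE instance
        tp_group=plugin.tp_group,
        moe_dp_group=plugin.moe_dp_group,
        ep_group=plugin.ep_group,
    )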
@@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    def __init__(self, *args, **kwargs):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
        assert tp_group is not None
        assert moe_dp_group is not None
        assert ep_group is not None
-        assert moe_tp_group is not None

        # setup ep group
        self.ep_size = dist.get_world_size(ep_group)

@@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):

        # setup global tp group
        self.tp_group = tp_group

-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
            for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
+                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
+                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

@@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
        tp_group: ProcessGroup,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EPMixtralSparseMoeBlock":
        # TODO: better init
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

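A note on the Linear1D_Col / Linear1D_Row pairing kept above: w1 and w3 are split along their output dimension and w2 along its input dimension, so the gated activation feeds the local w2 shard directly and only w2's output needs a cross-rank reduction. A shape-only sketch with hypothetical sizes (nn.Linear weights are (out, in)):

    # hidden=8, intermediate=16, tp=2: each rank holds half the output rows of
    # w1/w3 (column-parallel) and half the input columns of w2 (row-parallel),
    # so the elementwise gating and the w2 matmul stay local; ranks reduce
    # only w2's hidden-sized output.
    hidden, inter, tp = 8, 16, 2
    w1_shard = (inter // tp, hidden)  # Linear1D_Col shard shape
    w3_shard = (inter // tp, hidden)  # Linear1D_Col shard shape
    w2_shard = (hidden, inter // tp)  # Linear1D_Row shard shape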
@@ -154,7 +154,6 @@ class DeepseekPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                        },
                    )
                ],

@@ -155,7 +155,6 @@ class MixtralPolicy(Policy):
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                        },
                    )
                ],

@@ -50,7 +50,6 @@ class ShardConfig:
    # for moe related
    moe_dp_group: Optional[ProcessGroup] = None
    ep_group: Optional[ProcessGroup] = None
-    moe_tp_group: Optional[ProcessGroup] = None

    # pipeline_parallel_size: int
    # data_parallel_size: int

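For illustration, a sketch of how the slimmed-down config might now be filled (the process groups are hypothetical and would come from the plugin; field names are those of the dataclass above):

    # moe_tp_group is gone: expert tensor parallelism reuses the regular
    # tensor_parallel_process_group.
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        ep_group=ep_group,
        moe_dp_group=moe_dp_group,
        enable_tensor_parallelism=True,
    )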
@@ -47,7 +47,6 @@ def check_deepseek_moe_layer():
        model,
        ep_group=plugin.ep_group,
        moe_dp_group=plugin.moe_dp_group,
-        moe_tp_group=plugin.moe_tp_group,
        tp_group=plugin.tp_group,
    )
    ep_output = model(x)

@@ -42,7 +42,6 @@ def check_mixtral_moe_layer():
        ep_group=plugin.ep_group,
        tp_group=plugin.tp_group,
        moe_dp_group=plugin.moe_dp_group,
-        moe_tp_group=plugin.moe_tp_group,
    )
    ep_output, ep_logits = model(x)
    assert_close(orig_logits, ep_logits)

@@ -24,24 +24,28 @@ NUM_HEADS = 4
TOP_K = 2


-CHECKED_CONFIG = [  # FOR_WORLD=8
-    (2, 1, 1, 4, 1),
-    (4, 1, 1, 2, 1),
-    (4, 1, 1, 1, 1),
-    (2, 1, 2, 1, 1),
+CHECKED_CONFIG = [  # FOR_WORLD=4
+    (1, 4, 1, 1, 1),
+    (1, 1, 4, 1, 1),
+    (1, 1, 1, 4, 1),
+    (1, 1, 1, 1, 4),
+    (0, 1, 4, 1, 1),
+    (0, 1, 1, 4, 1),
+    (0, 1, 1, 1, 4),
+    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
-        (2, 1, 2, 1, 1),
-        # (2, 1, 1, 2, 1),
-        # (2, 1, 1, 1, 2),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
-    ep_size, stage, pp_size, tp_size, sp_size = config
+    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"

@@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
-        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,

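Note the tuple order change in these test configs: the leading element is now the ZeRO stage rather than ep_size. Decoding one of the new entries (values taken from the @parameterize list above):

    config = (1, 2, 1, 1, 2)  # from the updated @parameterize list
    stage, ep_size, pp_size, tp_size, sp_size = config
    # -> ZeRO stage 1, expert parallel size 2, no pipeline or tensor
    #    parallelism, sequence parallel size 2 ("all_to_all" mode)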
@@ -25,24 +25,28 @@ NUM_HEADS = 4
TOP_K = 1

CHECKED_CONFIG = [  # FOR WORLD=4
-    (2, 1, 2, 2, 1),
-    (2, 1, 1, 2, 1),
-    (2, 1, 4, 1, 1),
-    (4, 1, 1, 1, 1),
-    (4, 1, 1, 2, 1),
-    (4, 1, 2, 1, 1),
-    (2, 1, 2, 1, 1),
+    (0, 1, 4, 1, 1),
+    (0, 1, 1, 4, 1),
+    (0, 1, 1, 1, 4),
+    (1, 4, 1, 1, 1),
+    (1, 1, 4, 1, 1),
+    (1, 1, 1, 4, 1),
+    (1, 1, 1, 1, 4),
+    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
-        (2, 1, 1, 2, 1),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
+        (0, 2, 1, 1, 1),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
-    ep_size, stage, pp_size, tp_size, sp_size = config
+    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"

@@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
-        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,