Browse Source

[moe] refactor mesh assignment

colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent
commit
cb01c0d5ce
  1. 354
      colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
  2. 21
      colossalai/shardformer/modeling/deepseek.py
  3. 19
      colossalai/shardformer/modeling/mixtral.py
  4. 1
      colossalai/shardformer/policies/deepseek.py
  5. 1
      colossalai/shardformer/policies/mixtral.py
  6. 1
      colossalai/shardformer/shard/shard_config.py
  7. 1
      tests/test_moe/test_deepseek_layer.py
  8. 1
      tests/test_moe/test_mixtral_layer.py
  9. 23
      tests/test_shardformer/test_model/test_shard_deepseek.py
  10. 23
      tests/test_shardformer/test_model/test_shard_mixtral.py

354
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

@ -1,9 +1,8 @@
import warnings
from collections import defaultdict
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
from typing import Callable, List, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@ -13,6 +12,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.plugin.hybrid_parallel_plugin import (
PRECISION_TORCH_TYPE,
SUPPORT_SP_MODE,
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
@ -22,9 +23,16 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@ -57,7 +65,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(
WARN_STR
@ -105,129 +113,218 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
TODO: add docstring
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Extra Args:
ep_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
force_overlap_comm (bool): For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training. This flag is used to force overlap_communication=True.
"""
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False # default by true in super class
super().__init__(*args, **kwargs)
if ep_size <= 1:
raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
sp_size: int = None,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
force_overlap_comm: bool = False,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
)
assert (
self.sequence_parallelism_mode in SUPPORT_SP_MODE
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
assert (
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
sp_size == 1 or sp_size is None
), f"You should not set sp_size when sequence parallelism is not enabled."
self.sp_size = 1
assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
self.moe_dp_size = self.dp_size // ep_size
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
self._init_moe_param_comm()
self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
self.tp_size = tp_size
self.pp_size = pp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 0, 1
self.moe_pg_mesh = ProcessGroupMesh(
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 1, 2
self.moe_pg_mesh = ProcessGroupMesh(
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
)
if self.use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
# TODO it might make sense to support non-moe with tp on but moe with tp off
raise ValueError(
f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
# set param group in shard config
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
self.force_overlap_comm = force_overlap_comm
def _init_moe_param_comm(self):
world_size = dist.get_world_size()
if self.enable_sequence_parallelism:
if self.sequence_parallelism_mode == "all_to_all":
# if sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
)
# since we are reusing sp_group, moe_dp_group will be derived as dp_group
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group
self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
raise NotImplementedError(
f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
else:
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
self.dp_sp_group = self.dp_group
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
# hardcode here since we only have 3 axis
# moe_dp_group
for ep_idx in range(self.ep_size):
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)
if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
# NOTE: different tp settings between moe and non moe param are complex to handle
# we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
raise NotImplementedError(
f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
)
self.max_norm = max_norm
self.force_overlap_comm = force_overlap_comm
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
@ -249,14 +346,37 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_sp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@ -301,7 +421,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_sp_group,
dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,

21
colossalai/shardformer/modeling/deepseek.py

@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(
self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
):
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
assert moe_tp_group is not None
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module):
self.moe_dp_group = moe_dp_group
self.moe_dp_size = moe_dp_group.size()
# setup global tp group
# setup tp group
self.tp_group = tp_group
# setup moe tp group
self.moe_tp_group = moe_tp_group
if self.moe_tp_group.size() > 1:
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module):
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPDeepseekMoE":
@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

19
colossalai/shardformer/modeling/mixtral.py

@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(
self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
):
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
assert moe_tp_group is not None
# setup ep group
self.ep_size = dist.get_world_size(ep_group)
@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# setup global tp group
self.tp_group = tp_group
# setup moe tp group
self.moe_tp_group = moe_tp_group
if self.moe_tp_group.size() > 1:
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

1
colossalai/shardformer/policies/deepseek.py

@ -154,7 +154,6 @@ class DeepseekPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"moe_tp_group": self.shard_config.moe_tp_group,
},
)
],

1
colossalai/shardformer/policies/mixtral.py

@ -155,7 +155,6 @@ class MixtralPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"moe_tp_group": self.shard_config.moe_tp_group,
},
)
],

1
colossalai/shardformer/shard/shard_config.py

@ -50,7 +50,6 @@ class ShardConfig:
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
moe_tp_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int
# data_parallel_size: int

1
tests/test_moe/test_deepseek_layer.py

@ -47,7 +47,6 @@ def check_deepseek_moe_layer():
model,
ep_group=plugin.ep_group,
moe_dp_group=plugin.moe_dp_group,
moe_tp_group=plugin.moe_tp_group,
tp_group=plugin.tp_group,
)
ep_output = model(x)

1
tests/test_moe/test_mixtral_layer.py

@ -42,7 +42,6 @@ def check_mixtral_moe_layer():
ep_group=plugin.ep_group,
tp_group=plugin.tp_group,
moe_dp_group=plugin.moe_dp_group,
moe_tp_group=plugin.moe_tp_group,
)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)

23
tests/test_shardformer/test_model/test_shard_deepseek.py

@ -24,24 +24,28 @@ NUM_HEADS = 4
TOP_K = 2
CHECKED_CONFIG = [ # FOR_WORLD=8
(2, 1, 1, 4, 1),
(4, 1, 1, 2, 1),
(4, 1, 1, 1, 1),
(2, 1, 2, 1, 1),
CHECKED_CONFIG = [ # FOR_WORLD=4
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
(2, 1, 2, 1, 1),
# (2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,

23
tests/test_shardformer/test_model/test_shard_mixtral.py

@ -25,24 +25,28 @@ NUM_HEADS = 4
TOP_K = 1
CHECKED_CONFIG = [ # FOR WORLD=4
(2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
(2, 1, 4, 1, 1),
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
(4, 1, 2, 1, 1),
(2, 1, 2, 1, 1),
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
(2, 1, 1, 2, 1),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
(0, 2, 1, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,

Loading…
Cancel
Save