[moe] implement transit between non moe tp and ep

colossalchat
botbw 2024-07-08 09:59:46 +00:00 committed by Hongxin Liu
parent 37443cc7e4
commit b5bfeb2efd
7 changed files with 234 additions and 101 deletions

View File

@@ -1068,7 +1068,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}")
+        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
         self.stage_manager = None
         self.schedule = None

View File

@@ -30,8 +30,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
-        dp_process_group: ProcessGroup,  # the dp pg for comm
-        moe_dp_group: ProcessGroup,  # the moe dp pg for gomm
+        dp_process_group: ProcessGroup,  # dp pg for comm
+        moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
@@ -44,7 +44,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         verbose: bool = False,
         reduce_bucket_size: int = 1024 * 1024,  # communication
         communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = True,
+        overlap_communication: bool = False,
         partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
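
The optimizer takes two data-parallel groups because expert parameters and non-expert parameters are replicated across different sets of ranks. Below is a minimal sketch of one plausible way to split a model's parameters between the two groups, keyed off the `ep_group` attribute that the Mixtral modeling code attaches to expert parameters (how MoeHybridParallelZeroOptimizer does this internally is not shown in this diff; the module names are hypothetical):

    import torch.nn as nn

    model = nn.ModuleDict({"gate": nn.Linear(4, 4), "expert": nn.Linear(4, 4)})
    for p in model["expert"].parameters():
        p.ep_group = "placeholder"  # stands in for the real ProcessGroup set by the modeling code

    # expert params would be reduced over moe_dp_group, everything else over dp_process_group
    moe_params = [p for p in model.parameters() if hasattr(p, "ep_group")]
    non_moe_params = [p for p in model.parameters() if not hasattr(p, "ep_group")]
    print(len(moe_params), len(non_moe_params))  # 2 2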
@@ -88,7 +88,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     TODO: add docstring
     """

-    def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
+    def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@@ -98,14 +98,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )
             self.ddp_config["find_unused_parameters"] = True

-        if ep_tp_size != 1:
+        if moe_tp_size != 1:
             raise NotImplementedError

         world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (ep_size * ep_tp_size)
+        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
         self.ep_size = ep_size
-        self.moe_tp_size = ep_tp_size
+        self.moe_tp_size = moe_tp_size

         self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
         self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
@@ -114,7 +114,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
         self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
-        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}")
+        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
         # set ep_group after super init
         # TODO do it in a better way
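
The plugin keeps two process-group meshes over the same ranks: the inherited (pp, dp, tp, sp) mesh for non-MoE parameters and a separate (moe_dp, ep, moe_tp) mesh for the experts. A minimal sketch of the size bookkeeping, not part of the commit, assuming 8 ranks and hypothetical pp/tp/sp and ep choices (moe_tp must stay 1 for now):

    world_size = 8
    pp_size, tp_size, sp_size = 1, 2, 1                      # non-MoE axes (hypothetical)
    dp_size = world_size // (pp_size * tp_size * sp_size)    # 4
    ep_size, moe_tp_size = 2, 1                              # MoE axes
    moe_dp_size = world_size // (ep_size * moe_tp_size)      # 4
    assert pp_size * dp_size * tp_size * sp_size == world_size
    assert moe_dp_size * ep_size * moe_tp_size == world_size
    print(f"non-MoE mesh (pp, dp, tp, sp) = ({pp_size}, {dp_size}, {tp_size}, {sp_size})")
    print(f"MoE mesh (moe_dp, ep, moe_tp) = ({moe_dp_size}, {ep_size}, {moe_tp_size})")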

View File

@@ -397,3 +397,106 @@ def all_to_all_uneven(
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
     return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+
+# ===========================================================
+# This code section was modified from
+# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+#
+# The file has been adapted from the following Megatron-LM file:
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
+# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
+# We retain the following copyright from the original files:
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def _gather_tokens(input_, dim: int, tp_group: ProcessGroup):
+    """Gather tensors and concatenate them along a dimension"""
+    input_ = input_.contiguous()
+    # Size and dimension.
+    rank = tp_group.rank()
+
+    tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())]
+    tensor_list[rank] = input_
+    dist.all_gather(tensor_list, input_, group=tp_group)
+
+    # Note: torch.cat already creates a contiguous tensor.
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+    return output
+
+
+def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+    """Divide a tensor among the tensor parallel ranks"""
+    total_chunks = tp_group.size()
+    this_chunk = tp_group.rank()
+    assert input_.shape[
+        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+    chunk_size = input_.shape[dim] // total_chunks
+
+    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
+
+
+class _GatherTokens(torch.autograd.Function):
+    """All gather tokens among the tensor parallel ranks"""
+
+    @staticmethod
+    def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+        ctx.dim = dim
+        ctx.tp_group = tp_group
+        return _gather_tokens(input_, dim, tp_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+class _DropTokens(torch.autograd.Function):
+    "Divide tokens equally among the tensor parallel ranks"
+
+    @staticmethod
+    def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+        ctx.dim = dim
+        ctx.tp_group = tp_group
+        return _drop_tokens(input_, dim, tp_group)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        return _gather_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
+    if tp_group.size() == 1:
+        # no tensor parallelism for non-experts
+        return input_
+    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _GatherTokens.apply(input_, dim, tp_group)
+
+
+def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+    if tp_group.size() == 1:
+        # no tensor parallelism for non-experts
+        return input_
+    assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _DropTokens.apply(input_, dim, tp_group)
+# ===========================================================
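
drop_tokens keeps only the current TP rank's slice of the given dimension, and gather_tokens all-gathers the slices back; the two are exact inverses. A minimal local emulation of that round trip with torch.narrow / torch.cat (no process group involved), not part of the commit, just to show the shape bookkeeping for tp_size = 2:

    import torch

    tp_size = 2
    x = torch.randn(4, 6)                                             # (tokens, hidden)
    chunk = x.shape[-1] // tp_size
    # what _drop_tokens computes on each rank r
    shards = [torch.narrow(x, -1, r * chunk, chunk) for r in range(tp_size)]
    # what _gather_tokens reassembles after the all-gather
    assert torch.equal(torch.cat(shards, dim=-1), x)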

View File

@@ -14,21 +14,21 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none


 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group):
+    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
         super().__init__(config)
-        self.setup_ep(ep_group)
+        self.setup_process_groups(ep_group, tp_group, moe_tp_group)

-    def setup_ep(self, ep_group: ProcessGroup):
-        ep_group = ep_group
-        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
-        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None):
+        # setup ep group
+        self.ep_size = dist.get_world_size(ep_group)
+        self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group

         if self.num_experts % self.ep_size != 0:
@@ -42,13 +42,19 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         for p in self.experts.parameters():
             p.ep_group = ep_group

+        # setup global tp group
+        self.tp_group = tp_group
+
+        # setup moe tp group
+        self.moe_tp_group = moe_tp_group
+
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup] = None, moe_tp_group: Optional[ProcessGroup] = None, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_ep(ep_group)
+        module.setup_process_groups(ep_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -72,6 +78,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+
+        if self.tp_group is not None and self.tp_group.size() > 1:
+            dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
         output_states = MoeInGradScaler.apply(output_states, self.ep_size)
@@ -94,6 +104,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             output_states = torch.cat(output_states_list)
         output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+
+        if self.tp_group is not None and self.tp_group.size() > 1:
+            dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
+
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
             selected_experts_idx.size(0), device=selected_experts_idx.device
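
The dispatch path collapses per-expert token counts into per-EP-rank send/receive lists before calling all_to_all_uneven. A small standalone example of that bookkeeping, not part of the commit, with hypothetical counts for 8 experts spread over ep_size = 2 ranks:

    import torch

    ep_size, num_experts_per_ep = 2, 4                           # 8 experts, 4 per EP rank (hypothetical)
    input_split_sizes = torch.tensor([3, 1, 0, 2, 5, 0, 1, 4])   # tokens routed to each expert
    input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
    print(input_split_list)                                      # [6, 10] -> 6 tokens to EP rank 0, 10 to EP rank 1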

View File

@@ -8,6 +8,7 @@ from torch.nn import Module
 from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -20,15 +21,15 @@ class MixtralPolicy(Policy):
     def preprocess(self):
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError
-            # # Resize embedding
-            # vocab_size = self.model.config.vocab_size
-            # world_size = self.shard_config.tensor_parallel_size
-            # if vocab_size % world_size != 0:
-            #     new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            #     self.model.resize_token_embeddings(new_vocab_size)
+            # non-moe params tensor parallelism
+            # Resize embedding
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)

         return self.model
@@ -42,74 +43,62 @@ class MixtralPolicy(Policy):
             )

         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError
-            # assert (
-            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
-            # ), f"The number of attention heads must be divisible by tensor parallel size."
-            # assert (
-            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            # ), f"The number of key_value heads must be divisible by tensor parallel size."
-            # decoder_attribute_replacement = {
-            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-            #     // self.shard_config.tensor_parallel_size,
-            # }
-            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
-            #     attribute_replacement=decoder_attribute_replacement,
-            #     sub_module_replacement=[
-            #         SubModuleReplacementDescription(
-            #             suffix="self_attn.q_proj",
-            #             target_module=Linear1D_Col,
-            #             kwargs={
-            #                 'process_group': self.shard_config.tensor_parallel_process_group,
-            #             }
-            #         ),
-            #         SubModuleReplacementDescription(
-            #             suffix="self_attn.k_proj",
-            #             target_module=Linear1D_Col,
-            #             kwargs={
-            #                 'process_group': self.shard_config.tensor_parallel_process_group,
-            #             }
-            #         ),
-            #         SubModuleReplacementDescription(
-            #             suffix="self_attn.v_proj",
-            #             target_module=Linear1D_Col,
-            #             kwargs={
-            #                 'process_group': self.shard_config.tensor_parallel_process_group,
-            #             }
-            #         ),
-            #         SubModuleReplacementDescription(
-            #             suffix="self_attn.o_proj",
-            #             target_module=Linear1D_Row,
-            #             kwargs={
-            #                 'process_group': self.shard_config.tensor_parallel_process_group,
-            #             }
-            #         ),
-            #         # SubModuleReplacementDescription(
-            #         #     suffix="mlp.gate_proj",
-            #         #     target_module=Linear1D_Col,
-            #         # ),
-            #         # SubModuleReplacementDescription(
-            #         #     suffix="mlp.up_proj",
-            #         #     target_module=Linear1D_Col,
-            #         # ),
-            #         # SubModuleReplacementDescription(
-            #         #     suffix="mlp.down_proj",
-            #         #     target_module=Linear1D_Row,
-            #         # ),
-            #     ],
-            # )
+            # tensor parallelism for non-moe params
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+                // self.shard_config.tensor_parallel_size,
+            }
+            policy[MixtralDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=Linear1D_Row,
+                    ),
+                    # TODO: enable moe tp parallel
+                    # SubModuleReplacementDescription(
+                    #     suffix="mlp.gate_proj",
+                    #     target_module=Linear1D_Col,
+                    # ),
+                    # SubModuleReplacementDescription(
+                    #     suffix="mlp.up_proj",
+                    #     target_module=Linear1D_Col,
+                    # ),
+                    # SubModuleReplacementDescription(
+                    #     suffix="mlp.down_proj",
+                    #     target_module=Linear1D_Row,
+                    # ),
+                ],
+            )

-        if getattr(self.shard_config, "ep_group", None) is None:
+        if self.shard_config.ep_group:
             # expert parallel
             self.append_or_create_submodule_replacement(
                 description=[
                     SubModuleReplacementDescription(
                         suffix="block_sparse_moe",
                         target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
+                        kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
                     )
                 ],
                 policy=policy,
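
Linear1D_Col shards a linear layer's output features across the TP group and Linear1D_Row shards the input features, so a column-parallel projection followed by a row-parallel one (q/k/v_proj then o_proj above) needs only one reduction at the end. A minimal local sketch of that linear algebra, not part of the commit, with the all-reduce replaced by a plain sum over shards:

    import torch

    hidden, tp = 8, 2
    x = torch.randn(3, hidden)
    w_col = torch.randn(hidden, hidden)   # stands in for a column-parallel weight (in x out layout)
    w_row = torch.randn(hidden, hidden)   # stands in for a row-parallel weight

    y_parts = [x @ w for w in w_col.chunk(tp, dim=1)]                # each rank keeps a slice of the outputs
    z = sum(y @ w for y, w in zip(y_parts, w_row.chunk(tp, dim=0)))  # partial results summed (all-reduce in practice)

    assert torch.allclose(z, (x @ w_col) @ w_row, atol=1e-4)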

View File

@@ -47,6 +47,8 @@ class ShardConfig:
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     ep_group: Optional[ProcessGroup] = None
+    moe_tp_group: Optional[ProcessGroup] = None
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

View File

@@ -114,39 +114,64 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     "test_config",
     [
         {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 2,
+            "tp_size": 2,
+            "pp_size": 1,
             "ep_size": 1,
-            "zero_stage": 0,
+            "zero_stage": 2,
             "precision": "fp32",
-        },  # pp + ep
+        },  # [dp(2) + tp(2)] + [moe_dp(4)]
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "ep_size": 2,
+            "zero_stage": 2,
+            "precision": "fp32",
+        },  # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)]
         {
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
             "ep_size": 1,
-            "zero_stage": 0,
+            "zero_stage": 2,
             "precision": "fp32",
-        },  # pp + ep
+        },  # [dp(2) + pp(2)] + [moe_dp(4)]
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 2,
+            "precision": "fp32",
+        },  # [dp(2) + pp(2)] + [moe_dp(4)]
         {
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
             "ep_size": 4,
-            "zero_stage": 0,
+            "zero_stage": 2,
             "precision": "fp32",
-        },  # pp + ep
-        {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"},  # full dp for moe and non-moe
-        {  # moe_dp = 2, non_moe_dp = 4
+        },  # [dp(2) + pp(2)] + [ep(4))]
+        {
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 2,
-            "zero_stage": 1,
+            "zero_stage": 2,
             "precision": "fp32",
-        },  # moe_dp = 1, non_moe_dp = 4
-        {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"},  # full dp for non-moe and full ep for moe
-        {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"},  # full dp for moe and non-moe
+        },  # [dp(4)] + [ep(2) + moe_tp(2)]
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 4,
+            "zero_stage": 2,
+            "precision": "fp32"
+        },  # full dp for non-moe and full ep for moe
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "ep_size": 1,
+            "zero_stage": 2,
+            "precision": "fp32"
+        },  # full dp for moe and non-moe
     ],
 )
 def run_mixtral_test(test_config):
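
For reference, a quick sketch, not part of the commit, of how the second config above maps onto group sizes, assuming the test runs on 4 GPUs (moe_tp stays 1 since MoE tensor parallelism is not enabled yet):

    world_size = 4
    cfg = {"tp_size": 2, "pp_size": 1, "ep_size": 2, "zero_stage": 2}
    dp_size = world_size // (cfg["pp_size"] * cfg["tp_size"])      # 2 replicas of the non-MoE params
    moe_tp_size = 1
    moe_dp_size = world_size // (cfg["ep_size"] * moe_tp_size)     # 2 replicas of each expert
    print(dp_size, moe_dp_size)                                    # 2 2 -> "[dp(2) + tp(2)] + [ep(2) + moe_dp(2)]"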