[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
pull/6057/head
botbw 3 months ago committed by GitHub
parent 8fd25d6e09
commit c54c4fcd15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -64,13 +64,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None, forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False, overlap_allgather: bool = False,
): ):
pg_param_list = { if dp_process_group is moe_dp_group:
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), pg_param_list = {
moe_dp_group: list(filter(is_moe_tensor, model.parameters())), dp_process_group: list(model.parameters()),
} }
else:
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0: if len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in dp_process_group or moe_dp_group") raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")
super().__init__( super().__init__(
model=model, model=model,
@ -407,6 +412,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
and self.enable_sequence_parallelism and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all" and self.sequence_parallelism_mode == "all_to_all"
) )
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
if use_ddp: if use_ddp:
self.logger.warning( self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated", f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@ -414,17 +426,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
) )
self.ddp_config["find_unused_parameters"] = True self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError( raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0" f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
) )
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule( model = HybridParallelModule(
module=model, module=model,
precision=self.precision, precision=self.precision,
@ -466,6 +472,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
tp_process_group=self.tp_group, tp_process_group=self.tp_group,
) )
else: else:
is_zero = True
if self.dp_size <= 1: if self.dp_size <= 1:
self.logger.warning( self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "

@ -308,7 +308,7 @@ class EPGradScalerIn(torch.autograd.Function):
assert len(grad_outputs) == 1 assert len(grad_outputs) == 1
grad = grad_outputs[0] grad = grad_outputs[0]
if ctx.ep_size != 1: if ctx.ep_size != 1:
grad = grad * ctx.ep_size grad.mul_(ctx.ep_size)
return grad, None return grad, None
@ -328,7 +328,7 @@ class EPGradScalerOut(torch.autograd.Function):
assert len(grad_outputs) == 1 assert len(grad_outputs) == 1
grad = grad_outputs[0] grad = grad_outputs[0]
if ctx.ep_size != 1: if ctx.ep_size != 1:
grad = grad / ctx.ep_size grad.div_(ctx.ep_size)
return grad, None return grad, None
@ -449,7 +449,4 @@ def all_to_all_uneven(
overlap: bool = False, overlap: bool = False,
fp8_communication: bool = False, fp8_communication: bool = False,
): ):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)

@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.functional as F
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache from transformers.cache_utils import Cache, DynamicCache
@ -28,11 +28,13 @@ from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer._operation import (
all_to_all_comm, all_to_all_comm,
gather_forward_split_backward, gather_forward_split_backward,
linear_with_async_comm,
split_forward_gather_backward, split_forward_gather_backward,
) )
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@ -58,7 +60,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):
return grad_output, grad_loss return grad_output, grad_loss
class EPDeepseekMoE(nn.Module): class EPDeepseekMoE(ParallelModule):
def __init__(self): def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
@ -214,6 +216,79 @@ class EPDeepseekMoE(nn.Module):
return output_hidden_states return output_hidden_states
class DeepseekMoEGate_Col(ParallelModule):
def parallel_linear(self, hidden_states):
assert (
hidden_states.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
hidden_states.shape, self.weight.shape, self.weight.shape[-1]
)
output = linear_with_async_comm(
hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
)
# All-gather across the partitions.
output = gather_forward_split_backward(
output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = self.parallel_linear(hidden_states)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
@staticmethod
def from_native_module(
module, process_group: ProcessGroup, config, gather_output, fp8_communication
) -> "DeepseekMoEGate_Col":
LazyInitContext.materialize(module)
module.process_group = process_group
module.fp8_communication = fp8_communication
sharded_weight = shard_rowwise(module.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, module.weight)
module.__class__ = DeepseekMoEGate_Col
return module
class DeepseekPipelineForwards: class DeepseekPipelineForwards:
""" """
This class serves as a micro library for forward function substitution of Llama models This class serves as a micro library for forward function substitution of Llama models

@ -36,7 +36,7 @@ from colossalai.shardformer.layer._operation import (
gather_forward_split_backward, gather_forward_split_backward,
split_forward_gather_backward, split_forward_gather_backward,
) )
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@ -49,7 +49,7 @@ if is_flash_attn_2_available():
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): class EPMixtralSparseMoeBlock(ParallelModule):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

@ -10,6 +10,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import ( from colossalai.shardformer.modeling.deepseek import (
DeepseekMoEGate_Col,
DeepseekPipelineForwards, DeepseekPipelineForwards,
EPDeepseekMoE, EPDeepseekMoE,
get_deepseek_flash_attention_forward, get_deepseek_flash_attention_forward,
@ -56,16 +57,24 @@ class DeepseekPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all": if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = { decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size, "num_heads": num_q_heads,
} }
if getattr(self.model.config, "num_key_value_heads", False): if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription( policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
) )
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
@ -97,6 +106,7 @@ class DeepseekPolicy(Policy):
else: else:
if self.tie_weight: if self.tie_weight:
embedding_cls = PaddingEmbedding embedding_cls = PaddingEmbedding
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params # tensor parallelism for non-moe params
assert ( assert (
@ -107,10 +117,15 @@ class DeepseekPolicy(Policy):
), f"The number of key_value heads must be divisible by tensor parallel size." ), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = { decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
} }
num_q_heads //= tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": num_q_heads,
}
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy["DeepseekDecoderLayer"] = ModulePolicyDescription( policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -135,8 +150,19 @@ class DeepseekPolicy(Policy):
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={"fp8_communication": self.shard_config.fp8_communication},
), ),
SubModuleReplacementDescription(
suffix="mlp.gate",
target_module=DeepseekMoEGate_Col,
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"config": self.model.config,
},
ignore_if_not_exist=True,
),
], ],
) )
if embedding_cls is not None: if embedding_cls is not None:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(

@ -51,12 +51,20 @@ class MixtralPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all": if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = { decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size, "num_heads": num_q_heads,
} }
if getattr(self.model.config, "num_key_value_heads", False): if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription( policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -101,12 +109,14 @@ class MixtralPolicy(Policy):
assert ( assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size." ), f"The number of key_value heads must be divisible by tensor parallel size."
num_q_heads //= tp_size
decoder_attribute_replacement = { decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "self_attn.num_heads": num_q_heads,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
} }
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy[MixtralDecoderLayer] = ModulePolicyDescription( policy[MixtralDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -131,7 +141,7 @@ class MixtralPolicy(Policy):
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={"fp8_communication": self.shard_config.fp8_communication},
), ),
SubModuleReplacementDescription( # or replicate? SubModuleReplacementDescription(
suffix="block_sparse_moe.gate", suffix="block_sparse_moe.gate",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},

@ -0,0 +1,271 @@
# modified from mixtral benchmark
import argparse
import resource
import time
import warnings
from contextlib import nullcontext
import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================
# We have lots of llamas for your choice!
MODEL_CONFIGS = {
"100m": lambda: AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=1,
num_attention_heads=32,
intermediate_size=512,
moe_intermediate_size=128,
hidden_size=512,
n_routed_experts=8,
n_shared_experts=4,
num_experts_per_tok=2,
first_k_dense_replace=0,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
"7b": lambda: AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=13,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
"14b": lambda: AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=26,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
}
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
parser.add_argument(
"-p",
"--plugin",
choices=["3d"],
default="3d",
help="Choose which plugin to use",
)
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
parser.add_argument(
"-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
)
parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
"--nsys",
action="store_true",
help="Use nsys for profiling. \
You should put something like this before colossalai launch: \
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument(
"--sp_mode",
default="all_to_all",
choices=["all_to_all"],
help="Sequence parallelism mode",
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ckpt config for LLaMA3-70B on 64 H100 GPUs
hybrid_kwargs = (
{
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
}
if args.custom_ckpt
else {}
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "3d":
plugin = MoeHybridParallelPlugin(
ep_size=args.ep,
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
**hybrid_kwargs,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ==============================
# Initialize Dataset and Dataloader
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
config = MODEL_CONFIGS[args.config]()
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
# ==============================
# Initialize Model and Optimizer
# ==============================
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, MoeHybridParallelPlugin)
else nullcontext()
)
with init_ctx:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
performance_evaluator = PerformanceEvaluator(
model_numel,
model.config.num_hidden_layers,
model.config.hidden_size,
model.config.vocab_size,
args.grad_checkpoint,
args.ignore_steps,
dp_world_size=dp_size,
)
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
with get_profile_context(
args.profile,
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
) as prof: # , distributed_debug_mode(10, enable=True):
if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
print(f"rank {dist.get_rank()} step {step} passed")
else:
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
del outputs # free memory
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
main()

@ -105,7 +105,7 @@ def main():
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument( parser.add_argument(
"--sp_mode", "--sp_mode",
@ -151,6 +151,7 @@ def main():
max_prefetch=args.prefetch_num, max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce, enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "gemini_auto": elif args.plugin == "gemini_auto":
plugin = GeminiPlugin( plugin = GeminiPlugin(
@ -164,6 +165,7 @@ def main():
enable_async_reduce=not args.disable_async_reduce, enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "fsdp": elif args.plugin == "fsdp":
if use_empty_init: if use_empty_init:
@ -224,6 +226,7 @@ def main():
enable_metadata_cache=not args.no_cache, enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
**hybrid_kwargs, **hybrid_kwargs,
) )
elif args.plugin == "3d_cpu": elif args.plugin == "3d_cpu":
@ -241,6 +244,7 @@ def main():
precision="bf16", precision="bf16",
overlap_p2p=args.overlap, overlap_p2p=args.overlap,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
) )
else: else:
raise ValueError(f"Unknown plugin {args.plugin}") raise ValueError(f"Unknown plugin {args.plugin}")

@ -0,0 +1,259 @@
# modified from llama benchmark
import argparse
import resource
import time
import warnings
from contextlib import nullcontext
import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================
# We have lots of llamas for your choice!
MODEL_CONFIGS = {
"100m": MixtralConfig(
max_position_embeddings=4096,
num_hidden_layers=4,
num_attention_heads=32,
intermediate_size=768,
hidden_size=768,
attn_implementation="flash_attention_2",
),
"7b": MixtralConfig(
max_position_embeddings=4096,
num_hidden_layers=5,
attn_implementation="flash_attention_2",
),
"14b": MixtralConfig(
max_position_embeddings=4096,
num_hidden_layers=10,
attn_implementation="flash_attention_2",
),
}
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
parser.add_argument(
"-p",
"--plugin",
choices=["3d"],
default="3d",
help="Choose which plugin to use",
)
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
parser.add_argument(
"-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
)
parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
"--nsys",
action="store_true",
help="Use nsys for profiling. \
You should put something like this before colossalai launch: \
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument(
"--sp_mode",
default="all_to_all",
choices=["all_to_all"],
help="Sequence parallelism mode",
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ckpt config for LLaMA3-70B on 64 H100 GPUs
hybrid_kwargs = (
{
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
}
if args.custom_ckpt
else {}
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "3d":
plugin = MoeHybridParallelPlugin(
ep_size=args.ep,
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
**hybrid_kwargs,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ==============================
# Initialize Dataset and Dataloader
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = MixtralConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
# ==============================
# Initialize Model and Optimizer
# ==============================
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, MoeHybridParallelPlugin)
else nullcontext()
)
with init_ctx:
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
performance_evaluator = PerformanceEvaluator(
model_numel,
model.config.num_hidden_layers,
model.config.hidden_size,
model.config.vocab_size,
args.grad_checkpoint,
args.ignore_steps,
dp_world_size=dp_size,
)
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
with get_profile_context(
args.profile,
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
) as prof:
if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
del outputs # free memory
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
main()

@ -1,4 +1,12 @@
import os
import traceback
from contextlib import contextmanager
from time import sleep
from typing import Callable, List, Optional
import torch import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""): def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
return torch.allclose(a, b, rtol=rtol, atol=atol) return torch.allclose(a, b, rtol=rtol, atol=atol)
def check_model_equal(model1, model2): def check_model_equal(model1, model2, dtype):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())): for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
assert_loose_close(p1, p2, p1.dtype) assert_loose_close(p1, p2, dtype, name=name)
@contextmanager
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
if enable:
assert (
os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
if funcs_to_patch is None:
funcs_to_patch = [
dist.all_reduce,
dist.all_reduce_coalesced,
dist.all_gather,
dist.all_gather_coalesced,
dist.all_gather_into_tensor,
dist.all_to_all,
dist.all_to_all_single,
dist.reduce_scatter,
]
original_funcs = {}
patched_funcs = {}
def make_patched(func):
def patched_func(*args, **kwargs):
stack = traceback.format_stack()
def format_node(node):
if isinstance(node, torch.Tensor):
return f"{node.shape}"
elif isinstance(node, list):
return f"[{', '.join([format_node(n) for n in node])}]"
return str(node)
args_str, kwargs_str = tree_map(format_node, (args, kwargs))
en = len(stack) - 1
st = max(0, en - num_stacks)
dist.barrier()
sleep(0.001 * dist.get_rank())
print(
f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
)
dist.barrier()
return func(*args, **kwargs)
return patched_func
if enable:
for func in funcs_to_patch:
original_funcs[func.__name__] = getattr(dist, func.__name__)
patched_funcs[func.__name__] = make_patched(func)
setattr(dist, func.__name__, patched_funcs[func.__name__])
try:
yield
finally:
for func_name, original_func in original_funcs.items():
setattr(dist, func_name, original_func)

@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config):
dist.barrier() dist.barrier()
if dist.get_rank() == 0: if dist.get_rank() == 0:
saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype) saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(orig_model, saved_model) check_model_equal(orig_model, saved_model, dtype=dtype)
saved_model.save_pretrained(hf_model_dir) saved_model.save_pretrained(hf_model_dir)
dist.barrier() dist.barrier()
# check load model # check load model
@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config):
new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir) booster.load_model(new_model, hf_model_dir)
check_model_equal(model, new_model) check_model_equal(model, new_model, dtype=dtype)
# check save optimizer # check save optimizer
optimizer.step() optimizer.step()

@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel
import colossalai import colossalai
from colossalai.booster.booster import Booster from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8 NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2 NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4 NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4 HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4 NUM_HEADS = 8
TOP_K = 2 TOP_K = 2
CHECKED_CONFIG = [ # FOR_WORLD=4 def run_deepseek_commom(config: Tuple[int, ...]):
(1, 4, 1, 1, 1), Randomizer.reset_index()
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, pp_size, tp_size, sp_size = config stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()
dtype, precision = torch.float16, "fp16" dtype, precision = torch.bfloat16, "bf16"
torch.cuda.set_device(dist.get_rank()) torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
zero_stage=stage, zero_stage=stage,
enable_sequence_parallelism=sp_size > 1, enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
enable_flash_attention=sp_size > 1,
overlap_communication=False, overlap_communication=False,
initial_scale=1, initial_scale=1,
precision=precision, precision=precision,
find_unused_parameters=True, find_unused_parameters=True,
enable_flash_attention=True,
) )
dp_size = plugin.dp_size dp_size = plugin.dp_size
@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dist.barrier() dist.barrier()
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda() saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
check_model_equal(torch_model, saved_model) check_model_equal(torch_model, saved_model, dtype=dtype)
dist.barrier() dist.barrier()
if rank == world_size - 1: if rank == world_size - 1:
@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
print(f"rank {dist.get_rank()} test passed") print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port): @parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 2, 2, 1),
# zero 1
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 2, 1, 1, 2),
# zero 2
(2, 4, 1, 1, 1),
(2, 1, 4, 1, 1),
(2, 1, 1, 4, 1),
(2, 2, 1, 1, 2),
],
)
def run_deepseek_test(config: Tuple[int, ...]):
run_deepseek_commom(config)
@parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 2, 4, 1),
(0, 1, 4, 2, 1),
(0, 1, 1, 4, 1),
(0, 1, 4, 1, 1),
# zero 1:
(1, 2, 1, 1, 2),
(1, 2, 1, 4, 1),
(1, 1, 1, 2, 2),
(1, 2, 2, 2, 1),
# zero 2
(2, 2, 1, 1, 2),
(2, 2, 1, 4, 1),
(2, 1, 1, 2, 2),
(2, 2, 2, 2, 1),
],
)
def run_deepseek_3d_test(config: Tuple[int, ...]):
run_deepseek_commom(config)
def check_deepseek(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_deepseek_test()
def check_deepseek_3d(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model() run_deepseek_3d_test()
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_deepseek(world_size): def test_deepseek(world_size):
spawn(run_dist, world_size) spawn(check_deepseek, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_deepseek_3d(world_size):
spawn(check_deepseek_3d, world_size)
if __name__ == "__main__": if __name__ == "__main__":
test_deepseek(world_size=4) test_deepseek(world_size=8)
test_deepseek_3d(world_size=8)

@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai import colossalai
from colossalai.booster.booster import Booster from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8 NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4 NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4 HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4 NUM_HEADS = 8
TOP_K = 1 TOP_K = 2
CHECKED_CONFIG = [ # FOR WORLD=4
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
def run_mixtral_commom(config: Tuple[int, ...]):
@parameterize( Randomizer.reset_index()
"config",
[
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, pp_size, tp_size, sp_size = config stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()
dtype, precision = torch.float16, "fp16" dtype, precision = torch.bfloat16, "bf16"
torch.cuda.set_device(dist.get_rank()) torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dist.barrier() dist.barrier()
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype) saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model) check_model_equal(torch_model, saved_model, dtype=dtype)
dist.barrier() dist.barrier()
if rank == world_size - 1: if rank == world_size - 1:
@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
print(f"rank {dist.get_rank()} test passed") print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port): @parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 2, 2, 1),
# zero 1
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 2, 1, 1, 2),
# zero 2
(2, 4, 1, 1, 1),
(2, 1, 4, 1, 1),
(2, 1, 1, 4, 1),
(2, 2, 1, 1, 2),
],
)
def run_mixtral_test(config: Tuple[int, ...]):
run_mixtral_commom(config)
@parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 2, 4, 1),
(0, 1, 4, 2, 1),
(0, 1, 1, 4, 1),
(0, 1, 4, 1, 1),
# zero 1:
(1, 2, 1, 1, 2),
(1, 2, 1, 4, 1),
(1, 1, 1, 2, 2),
(1, 2, 2, 2, 1),
# zero 2
(2, 2, 1, 1, 2),
(2, 2, 1, 4, 1),
(2, 1, 1, 2, 2),
(2, 2, 2, 2, 1),
],
)
def run_mixtral_3d_test(config: Tuple[int, ...]):
print(f"{config=}")
run_mixtral_commom(config)
def check_mixtral(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mixtral_test()
def check_mixtral_3d(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model() run_mixtral_3d_test()
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_mixtral(world_size): def test_mixtral(world_size):
spawn(run_dist, world_size) spawn(check_mixtral, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mixtral_3d(world_size):
spawn(check_mixtral_3d, world_size)
if __name__ == "__main__": if __name__ == "__main__":
test_mixtral(world_size=4) test_mixtral(world_size=8)
test_mixtral_3d(world_size=8)

Loading…
Cancel
Save