[shardformer] support ep for deepseek v3 (#6185)

* [feature] support ep for deepseek v3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* [shardformer] fix deepseek v3 init

* [lazy] fit lora for lazy init

* [example] support npu for deepseek v3

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2025-02-11 16:10:25 +08:00 committed by GitHub
parent 17062c83b9
commit 2b415e5999
13 changed files with 612 additions and 22 deletions
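As a quick usage sketch of what this commit enables (assumed wiring, mirroring the benchmark and test configs in the diffs below; the shrunken config values are illustrative and the script is meant to be launched with torchrun on 2+ ranks):

import colossalai
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()

# ep_size shards the routed experts across ranks; zero_stage/precision follow the test config below.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, zero_stage=1, precision="bf16")
booster = Booster(plugin=plugin)

config = AutoConfig.from_pretrained(
    "deepseek-ai/DeepSeek-V3",
    num_hidden_layers=3,        # shrunken values, illustration only
    first_k_dense_replace=1,
    n_routed_experts=16,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)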

@@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelPlugin,
HybridParallelZeroOptimizer,
get_param_info,
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -468,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
# if ep is enabled, the number of (moe) parameters changes since they are sharded among ep groups
# but the optimizer is not aware of ep, so we need to update the optimizer
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
@@ -489,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = HybridParallelNaiveOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
@@ -507,7 +501,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,

@@ -64,7 +64,10 @@ class ProcessGroupMesh:
system resources.
"""
for group in self._ranks_to_group.values():
dist.destroy_process_group(group)
try:
dist.destroy_process_group(group)
except ValueError:
pass
# Manually clear all process groups to save memory
gc.collect()

@@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
return tensor.data.tolist()
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
@@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the converted tensor
"""
requires_grad = target.requires_grad if requires_grad is None else requires_grad
cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
if cls_to_become is Parameter:
# to fit UninitializedParameter
delattr(tensor, "_is_param")
tensor.data = target
tensor.requires_grad = target.requires_grad
tensor.requires_grad = requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(_data_tolist, tensor)
@@ -212,9 +213,10 @@ class LazyTensor(torch.Tensor):
Returns:
torch.Tensor: The materialized tensor (self).
"""
requires_grad = self.requires_grad
target = self._materialize_data()
self.clean()
return _convert_cls(self, target)
return _convert_cls(self, target, requires_grad=requires_grad)
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""

@@ -0,0 +1,277 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
class EpDeepseekV3MoE(ParallelModule):
"""
A mixed expert module containing shared experts.
"""
def __init__(self, config):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(
self,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
):
assert moe_dp_group is not None
assert ep_group is not None
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.experts_per_rank = self.num_experts_per_ep
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
# setup moe_dp group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = dist.get_world_size(moe_dp_group)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@staticmethod
def from_native_module(
module,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
*args,
**kwargs,
) -> "EpDeepseekV3MoE":
if module.__class__.__name__ != "DeepseekV3MLP":
module.__class__ = EpDeepseekV3MoE
module.setup_process_groups(moe_dp_group, ep_group)
LazyInitContext.materialize(module)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
return y
def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
if self.ep_size > 1:
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
input_split_sizes = tokens_per_ep_rank.tolist()
gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
s = 0
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
gatherd_idxs[s : s + k] = i % self.experts_per_rank
s += k
gatherd_idxs = gatherd_idxs.argsort()
sorted_tokens = gathered_tokens[gatherd_idxs]
tokens_per_expert = tokens_per_expert_post_gather
# moe-dp related code
activate_experts = tokens_per_expert_post_gather > 0
activate_experts = activate_experts.int()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
# ep related code
sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
# moe-dp related code
tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
expert_out = expert(tokens_for_this_expert)
# moe-dp related code
expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
outputs.append(expert_out)
start_idx = end_idx
if len(outputs) > 0:
outs = torch.cat(outputs, dim=0)
else:
assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
outs = sorted_tokens
if self.ep_size > 1:
outs = EPGradScalerOut.apply(outs, self.ep_size)
new_x = torch.empty_like(outs)
new_x[gatherd_idxs] = outs
gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
outs = gathered_tokens
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
(new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
def deepseek_v3_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and i > 0:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
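The dispatch bookkeeping in moe_forward above is dense, so here is a toy single-process sketch (made-up shapes: 4 tokens, 8 experts, top-2 routing, ep_size=2) of how the counts, sort order, and all-to-all split sizes come about:

import torch

x = torch.randn(4, 16)                                     # [num_tokens, hidden]
topk_ids = torch.tensor([[0, 3], [3, 5], [1, 3], [0, 7]])  # [num_tokens, top_k]

cnts = topk_ids.new_zeros((topk_ids.shape[0], 8))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)           # token copies destined for each expert
idxs = topk_ids.view(-1).argsort()            # flat routing slots ordered by expert id
sorted_tokens = x[idxs // topk_ids.shape[1]]  # duplicated tokens, grouped by expert

# With ep_size=2 (4 experts per rank), the uneven all-to-all send splits are the
# per-EP-rank totals of tokens_per_expert.
input_split_sizes = tokens_per_expert.view(2, -1).sum(dim=1).tolist()
print(tokens_per_expert.tolist(), input_split_sizes)       # [2, 1, 0, 3, 0, 1, 0, 1] [6, 2]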

@@ -167,6 +167,13 @@ _POLICY_LIST = {
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# DeepseekV3
"transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"

@@ -0,0 +1,83 @@
from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["DeepseekV3Policy", "DeepseekV3ModelPolicy", "DeepseekV3ForCausalLMPolicy"]
class DeepseekV3Policy(Policy):
def config_sanity_check(self):
assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
def preprocess(self):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
# support gradient checkpointing
# policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward})
if self.shard_config.expert_parallel_size > 1:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EpDeepseekV3MoE,
kwargs={
"ep_group": self.shard_config.ep_group,
"moe_dp_group": self.shard_config.moe_dp_group,
},
)
],
policy=policy,
target_key="DeepseekV3DecoderLayer",
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# TODO: prevent casting to fp32
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key="DeepseekV3DecoderLayer",
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key="DeepseekV3Model",
)
return policy
def postprocess(self):
return self.model
class DeepseekV3ModelPolicy(DeepseekV3Policy):
pass
class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
pass

@@ -68,6 +68,10 @@ class ShardConfig:
def sequence_parallel_size(self):
return self._sequence_parallel_size
@property
def expert_parallel_size(self):
return self._expert_parallel_size
def __post_init__(self):
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
@@ -103,6 +107,8 @@ class ShardConfig:
else:
self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
self._expert_parallel_size = dist.get_world_size(self.ep_group) if self.ep_group else 1
def _turn_on_all_optimization(self):
"""
Turn on all optimization.
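A small sketch (assumed to run inside a torch.distributed job that is already initialized, e.g. via colossalai.launch_from_torch) of the new property: DeepseekV3Policy keys its expert-parallel branch off expert_parallel_size, which is simply the world size of ep_group and falls back to 1 when no EP group is configured.

import torch.distributed as dist

from colossalai.shardformer.shard.shard_config import ShardConfig

ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))
cfg = ShardConfig(enable_tensor_parallelism=False, ep_group=ep_group)
assert cfg.expert_parallel_size == dist.get_world_size(ep_group)

cfg_no_ep = ShardConfig(enable_tensor_parallelism=False)  # ep_group defaults to None
assert cfg_no_ep.expert_parallel_size == 1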

@@ -4,11 +4,13 @@ import resource
import time
import warnings
from contextlib import nullcontext
from types import MethodType
import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from peft import LoraConfig
from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
@@ -29,7 +31,7 @@ warnings.filterwarnings("ignore")
# We have lots of llamas for your choice!
MODEL_CONFIGS = {
"100m": lambda: AutoConfig.from_pretrained(
"100m": AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=1,
@@ -44,20 +46,29 @@ MODEL_CONFIGS = {
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
"7b": lambda: AutoConfig.from_pretrained(
"7b": AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=13,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
"14b": lambda: AutoConfig.from_pretrained(
"14b": AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
max_position_embeddings=4096,
num_hidden_layers=26,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
"v3-6b": AutoConfig.from_pretrained(
"deepseek-ai/DeepSeek-V3",
num_hidden_layers=5,
first_k_dense_replace=2,
n_routed_experts=32,
vocab_size=8192,
attn_implementation="flash_attention_2",
trust_remote_code=True,
),
}
@@ -119,6 +130,7 @@ def main():
help="Sequence parallelism mode",
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
parser.add_argument("--enable_lora", action="store_true", help="Enable LoRA")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -151,7 +163,7 @@ def main():
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
@@ -171,7 +183,10 @@ def main():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
config = MODEL_CONFIGS[args.config]()
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
@@ -189,11 +204,25 @@ def main():
else nullcontext()
)
attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
with init_ctx:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
).to(torch.bfloat16)
if args.enable_lora:
booster.enable_lora(
model,
lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if model.__class__.__name__.startswith("DeepseekV3"):
model.eval()
# enable grad for moe layers
for m in model.modules():
if m.__class__.__name__ == "DeepseekV3MoE":
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
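The MethodType(m.moe_infer.__wrapped__, m) rebinding above is needed because the upstream DeepSeek modeling code wraps moe_infer in torch.no_grad(), which would cut the autograd graph during training. A self-contained sketch of the same trick (Demo is a stand-in class, not the real module):

from types import MethodType

import torch

class Demo:
    @torch.no_grad()
    def moe_infer(self, x):
        return x * 2

m = Demo()
# functools.wraps (used by the no_grad decorator) exposes the original function
# as __wrapped__; rebinding it restores gradient flow for training.
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)

x = torch.ones(2, requires_grad=True)
assert m.moe_infer(x).requires_grad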

@@ -7,6 +7,7 @@ from torch import Tensor
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
def divide(x: float, y: float) -> float:
@@ -29,7 +30,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
# tensor = tensor / world_size
# return tensor.item()
tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
tensor = torch.tensor([x], device=get_current_device(), dtype=torch.float)
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()

@@ -5,6 +5,7 @@ from .bloom import *
from .chatglm2 import *
from .command import *
from .deepseek import *
from .deepseek_v3 import *
from .falcon import *
from .gpt import *
from .gptj import *

@@ -0,0 +1,87 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
from types import MethodType
import torch
import transformers
from transformers import AutoConfig
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence DeepSeek-V3
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss
def init_deepseek():
config = AutoConfig.from_pretrained(
"deepseek-ai/DeepSeek-V3",
hidden_size=128,
intermediate_size=320,
kv_lora_rank=4,
moe_intermediate_size=32,
num_attention_heads=4,
num_experts_per_tok=4,
n_group=4,
num_hidden_layers=3,
num_key_value_heads=4,
first_k_dense_replace=1,
q_lora_rank=8,
torch_dtype="bfloat16",
n_routed_experts=16,
topk_group=2,
v_head_dim=32,
qk_nope_head_dim=32,
qk_rope_head_dim=32,
trust_remote_code=True,
vocab_size=2048,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)
# enable grad for moe layers
for m in model.modules():
if m.__class__.__name__ == "DeepseekV3MoE":
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
return model
model_zoo.register(
name="transformers_deepseek_v3",
model_fn=init_deepseek,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)

@@ -223,7 +223,6 @@ def run_forward_backward_with_hybrid_plugin(
for k, v in data.items():
unshard_test_data[k] = data[k].clone()
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in shard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
@@ -248,7 +247,6 @@ def run_forward_backward_with_hybrid_plugin(
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
if booster.plugin.stage_manager is not None:
for k, v in unshard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:

@@ -0,0 +1,102 @@
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
seed_all(42)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin
)
if enable_gradient_checkpointing:
# org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
org_model = org_model.to(torch.bfloat16)
org_model.eval()
sharded_model.eval()
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
assert_close(org_loss, sharded_loss)
param_dict = {n: p for n, p in org_model.named_parameters()}
for n, p in sharded_model.unwrap().named_parameters():
if n in param_dict:
if booster.plugin.zero_stage == 0:
grad = p.grad
target_grad = param_dict[n].grad
else:
grad = sharded_optimizer.get_working_grad_by_param_id(id(p))
pg = sharded_optimizer.param_to_pg[p]
target_grad = param_dict[n].grad
if target_grad is None:
continue
target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
assert_close(grad, target_grad, atol=3e-1, rtol=0)
@parameterize(
"config",
[
# zero 1
(1, 4),
(1, 2),
],
)
def run_deepseek_v3_test(config: Tuple[int, ...]):
zero_stage, ep_size = config
plugin_config = dict(
pp_size=1,
tp_size=1,
ep_size=ep_size,
zero_stage=zero_stage,
overlap_communication=False,
precision="bf16",
find_unused_parameters=True,
)
sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
plugin_config,
)
def check_deepseek_v3(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_deepseek_v3_test()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek_v3(world_size):
spawn(check_deepseek_v3, world_size)
if __name__ == "__main__":
test_deepseek_v3(world_size=4)
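For reference, a tiny sketch of the parallel layout these configs exercise; the moe-DP relationship is an assumption matching how MoeHybridParallelPlugin splits the 4 ranks when pp_size = tp_size = 1:

world_size = 4
for zero_stage, ep_size in [(1, 4), (1, 2)]:
    moe_dp_size = world_size // ep_size  # ranks left over for MoE data parallelism
    print(f"zero_stage={zero_stage} ep_size={ep_size} -> moe_dp_size={moe_dp_size}")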