[shardformer] support ep for deepseek v3 (#6185)

* [feature] support ep for deepseek v3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* [shardformer] fix deepseek v3 init

* [lazy] fit lora for lazy init

* [example] support npu for deepseek v3

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2025-02-11 16:10:25 +08:00 committed by GitHub
parent 17062c83b9
commit 2b415e5999
13 changed files with 612 additions and 22 deletions
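
As context for the feature bullet above, here is a minimal usage sketch of what these changes enable: sharding DeepSeek-V3's routed experts across an expert-parallel group through MoeHybridParallelPlugin. It is illustrative only; the model size, group sizes and optimizer choice are assumptions, not taken from this commit.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin
from transformers import AutoConfig, AutoModelForCausalLM

colossalai.launch_from_torch()

# each rank keeps n_routed_experts / ep_size experts; ep_size=4 is an arbitrary example
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=4, zero_stage=1, precision="bf16")
booster = Booster(plugin=plugin)

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", num_hidden_layers=5, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer, *_ = booster.boost(model, optimizer)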

View File

@@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelPlugin,
     HybridParallelZeroOptimizer,
     get_param_info,
-    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -468,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             use_fp8=self.use_fp8,
         )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.ep_size > 1:
-                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
-                # but the optimizer is not aware of ep, so we need to update the optimizer
-                reinitialize_optimizer(optimizer, model)
             if self.zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
@@ -489,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
@@ -507,7 +501,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
-                    use_pipeline=self.enable_pipeline_parallelism,
+                    use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                     param_info=param_info,
                     dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,

View File

@@ -64,7 +64,10 @@ class ProcessGroupMesh:
         system resources.
         """
         for group in self._ranks_to_group.values():
-            dist.destroy_process_group(group)
+            try:
+                dist.destroy_process_group(group)
+            except ValueError:
+                pass
         # Manually clear all process groups to save memory
         gc.collect()
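
The guard above matters because destroy_process_group raises ValueError for a group that is no longer registered, presumably when a shared group has already been torn down elsewhere. The same pattern in isolation, as a hypothetical helper that is not part of this commit:

import torch.distributed as dist

def destroy_groups_safely(groups):
    for group in groups:
        try:
            dist.destroy_process_group(group)
        except ValueError:
            # the group was already destroyed elsewhere; treat teardown as idempotent
            pass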

View File

@@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
     return tensor.data.tolist()
-def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
@@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
+    requires_grad = target.requires_grad if requires_grad is None else requires_grad
     cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
     if cls_to_become is Parameter:
         # to fit UninitializedParameter
         delattr(tensor, "_is_param")
     tensor.data = target
-    tensor.requires_grad = target.requires_grad
+    tensor.requires_grad = requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
     tensor.tolist = MethodType(_data_tolist, tensor)
@@ -212,9 +213,10 @@ class LazyTensor(torch.Tensor):
         Returns:
             torch.Tensor: The materialized tensor (self).
         """
+        requires_grad = self.requires_grad
         target = self._materialize_data()
         self.clean()
-        return _convert_cls(self, target)
+        return _convert_cls(self, target, requires_grad=requires_grad)
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""

View File

@@ -0,0 +1,277 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        if module.__class__.__name__ != "DeepseekV3MLP":
            module.__class__ = EpDeepseekV3MoE
            module.setup_process_groups(moe_dp_group, ep_group)
        LazyInitContext.materialize(module)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)
            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

        # moe-dp related code
        activate_experts = tokens_per_expert_post_gather > 0
        activate_experts = activate_experts.int()
        dist.all_reduce(activate_experts, group=self.moe_dp_group)

        # ep related code
        sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out


def deepseek_v3_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for i, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and i > 0:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
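
The dispatch bookkeeping at the top of moe_forward (cnts / scatter_ / argsort) is easiest to see on a toy batch. The snippet below is a standalone illustration with made-up shapes, not code from this file:

import torch

num_experts, top_k = 4, 2
x = torch.randn(3, 8)                                  # 3 tokens, hidden size 8
topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]])      # router choices per token

cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids, 1)                          # one-hot of chosen experts per token
tokens_per_expert = cnts.sum(dim=0)                    # tensor([2, 1, 2, 1])

idxs = topk_ids.view(-1).argsort()                     # flat routing slots grouped by expert id
sorted_tokens = x[idxs // top_k]                       # each token row duplicated once per routed expert
assert sorted_tokens.shape == (3 * top_k, 8)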

View File

@@ -167,6 +167,13 @@ _POLICY_LIST = {
     "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
         file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
     ),
+    # DeepseekV3
+    "transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
+    ),
+    "transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
+    ),
     # Falcon
     "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
         file_name="falcon", class_name="FalconModelPolicy"

View File

@@ -0,0 +1,83 @@
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        # support gradient checkpointing
        # policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward})

        if self.shard_config.expert_parallel_size > 1:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    pass


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    pass

View File

@@ -68,6 +68,10 @@ class ShardConfig:
     def sequence_parallel_size(self):
         return self._sequence_parallel_size
+    @property
+    def expert_parallel_size(self):
+        return self._expert_parallel_size
     def __post_init__(self):
         # turn on all optimization if all_optimization is set to True
         if self.enable_all_optimization:
@@ -103,6 +107,8 @@ class ShardConfig:
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
+        self._expert_parallel_size = dist.get_world_size(self.ep_group) if self.ep_group else 1
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
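
A hypothetical sketch of how the new property is expected to behave, assuming torch.distributed is already initialized (e.g. via colossalai.launch_from_torch()) and ShardConfig's other defaults are left untouched:

import torch.distributed as dist
from colossalai.shardformer import ShardConfig

ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))   # all ranks form one EP group
shard_config = ShardConfig(enable_tensor_parallelism=False, ep_group=ep_group)
assert shard_config.expert_parallel_size == dist.get_world_size(ep_group)

# without an ep_group, the property falls back to 1
assert ShardConfig(enable_tensor_parallelism=False).expert_parallel_size == 1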

View File

@@ -4,11 +4,13 @@ import resource
 import time
 import warnings
 from contextlib import nullcontext
+from types import MethodType
 import torch
 import torch.distributed as dist
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
+from peft import LoraConfig
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -29,7 +31,7 @@ warnings.filterwarnings("ignore")
 # We have lots of llamas for your choice!
 MODEL_CONFIGS = {
-    "100m": lambda: AutoConfig.from_pretrained(
+    "100m": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=1,
@@ -44,20 +46,29 @@ MODEL_CONFIGS = {
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "7b": lambda: AutoConfig.from_pretrained(
+    "7b": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=13,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "14b": lambda: AutoConfig.from_pretrained(
+    "14b": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=26,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
+    "v3-6b": AutoConfig.from_pretrained(
+        "deepseek-ai/DeepSeek-V3",
+        num_hidden_layers=5,
+        first_k_dense_replace=2,
+        n_routed_experts=32,
+        vocab_size=8192,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ),
 }
@@ -119,6 +130,7 @@ def main():
         help="Sequence parallelism mode",
     )
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+    parser.add_argument("--enable_lora", action="store_true", help="Enable LoRA")
     args = parser.parse_args()
     colossalai.launch_from_torch()
@@ -151,7 +163,7 @@ def main():
             sp_size=args.sp,
             sequence_parallelism_mode=args.sp_mode,
             enable_sequence_parallelism=args.sp > 1,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
@@ -171,7 +183,10 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)
-    config = MODEL_CONFIGS[args.config]()
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
     torch.cuda.manual_seed(42)
@@ -189,11 +204,25 @@ def main():
         else nullcontext()
     )
+    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
+        model = AutoModelForCausalLM.from_config(
+            config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
+        ).to(torch.bfloat16)
+    if args.enable_lora:
+        booster.enable_lora(
+            model,
+            lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
+    if model.__class__.__name__.startswith("DeepseekV3"):
+        model.eval()
+        # enable grad for moe layers
+        for m in model.modules():
+            if m.__class__.__name__ == "DeepseekV3MoE":
+                m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

View File

@@ -7,6 +7,7 @@ from torch import Tensor
 from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
 from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
 def divide(x: float, y: float) -> float:
@@ -29,7 +30,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     #     tensor = tensor / world_size
     #     return tensor.item()
-    tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
+    tensor = torch.tensor([x], device=get_current_device(), dtype=torch.float)
     dist.all_reduce(tensor)
     tensor = tensor / world_size
     return tensor.item()

View File

@@ -5,6 +5,7 @@ from .bloom import *
 from .chatglm2 import *
 from .command import *
 from .deepseek import *
+from .deepseek_v3 import *
 from .falcon import *
 from .gpt import *
 from .gptj import *

View File

@@ -0,0 +1,87 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
from types import MethodType

import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence Mixtral
# ===============================


def data_gen():
    # Generated from following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss


def init_deepseek():
    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3",
        hidden_size=128,
        intermediate_size=320,
        kv_lora_rank=4,
        moe_intermediate_size=32,
        num_attention_heads=4,
        num_experts_per_tok=4,
        n_group=4,
        num_hidden_layers=3,
        num_key_value_heads=4,
        first_k_dense_replace=1,
        q_lora_rank=8,
        torch_dtype="bfloat16",
        n_routed_experts=16,
        topk_group=2,
        v_head_dim=32,
        qk_nope_head_dim=32,
        qk_rope_head_dim=32,
        trust_remote_code=True,
        vocab_size=2048,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    # enable grad for moe layers
    for m in model.modules():
        if m.__class__.__name__ == "DeepseekV3MoE":
            m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
    return model


model_zoo.register(
    name="transformers_deepseek_v3",
    model_fn=init_deepseek,
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
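
Once registered, the entry can be pulled back out of the zoo and exercised directly; a minimal single-process sketch of roughly what the new shardformer test below does with it (no sharding involved):

from tests.kit.model_zoo import model_zoo

sub_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()                                   # tiny DeepSeek-V3 built from the config above
    batch = data_gen_fn()                                # input_ids, attention_mask and labels
    loss = loss_fn(output_transform_fn(model(**batch)))  # causal LM loss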

View File

@@ -223,7 +223,6 @@ def run_forward_backward_with_hybrid_plugin(
     for k, v in data.items():
         unshard_test_data[k] = data[k].clone()
-    sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in shard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
@@ -248,7 +247,6 @@ def run_forward_backward_with_hybrid_plugin(
     sharded_loss = criterion(sharded_output)
     sharded_optimizer.backward(sharded_loss)
-    org_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in unshard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:

View File

@@ -0,0 +1,102 @@
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    seed_all(42)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    org_model = org_model.to(torch.bfloat16)
    org_model.eval()
    sharded_model.eval()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )
    assert_close(org_loss, sharded_loss)

    param_dict = {n: p for n, p in org_model.named_parameters()}
    for n, p in sharded_model.unwrap().named_parameters():
        if n in param_dict:
            if booster.plugin.zero_stage == 0:
                grad = p.grad
                target_grad = param_dict[n].grad
            else:
                grad = sharded_optimizer.get_working_grad_by_param_id(id(p))
                pg = sharded_optimizer.param_to_pg[p]
                target_grad = param_dict[n].grad
                if target_grad is None:
                    continue
                target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
            assert_close(grad, target_grad, atol=3e-1, rtol=0)


@parameterize(
    "config",
    [
        # zero 1
        (1, 4),
        (1, 2),
    ],
)
def run_deepseek_v3_test(config: Tuple[int, ...]):
    zero_stage, ep_size = config
    plugin_config = dict(
        pp_size=1,
        tp_size=1,
        ep_size=ep_size,
        zero_stage=zero_stage,
        overlap_communication=False,
        precision="bf16",
        find_unused_parameters=True,
    )

    sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(
            model_fn,
            data_gen_fn,
            output_transform_fn,
            loss_fn,
            plugin_config,
        )


def check_deepseek_v3(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_v3_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek_v3(world_size):
    spawn(check_deepseek_v3, world_size)


if __name__ == "__main__":
    test_deepseek_v3(world_size=4)