mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support ep for deepseek v3 (#6185)
* [feature] support ep for deepseek v3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test * [shardformer] fix deepseek v3 init * [lazy] fit lora for lazy init * [example] support npu for deepseek v3 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6073/merge
parent
17062c83b9
commit
2b415e5999
|
@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
|
|||
HybridParallelPlugin,
|
||||
HybridParallelZeroOptimizer,
|
||||
get_param_info,
|
||||
reinitialize_optimizer,
|
||||
)
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
|
||||
|
@ -468,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.ep_size > 1:
|
||||
# if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
|
||||
# but the optimizer is not aware of ep, so we need to update the optimizer
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
|
||||
if self.zero_stage == 0:
|
||||
is_zero = False
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
|
@ -489,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
optimizer = HybridParallelNaiveOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
|
||||
param_info=param_info,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
|
@ -507,7 +501,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
optimizer = MoeHybridParallelZeroOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
|
||||
param_info=param_info,
|
||||
dp_process_group=self.mixed_dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
|
|
|
@ -64,7 +64,10 @@ class ProcessGroupMesh:
|
|||
system resources.
|
||||
"""
|
||||
for group in self._ranks_to_group.values():
|
||||
dist.destroy_process_group(group)
|
||||
try:
|
||||
dist.destroy_process_group(group)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Manually clear all process groups to save memory
|
||||
gc.collect()
|
||||
|
|
|
@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
|
|||
return tensor.data.tolist()
|
||||
|
||||
|
||||
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
|
||||
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
|
||||
"""Convert a lazy tensor's class to target's class, with target's data.
|
||||
|
||||
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
|
||||
|
@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
|
|||
Returns:
|
||||
torch.Tensor: the converted tensor
|
||||
"""
|
||||
requires_grad = target.requires_grad if requires_grad is None else requires_grad
|
||||
cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
|
||||
tensor.__class__ = cls_to_become
|
||||
if cls_to_become is Parameter:
|
||||
# to fit UninitializedParameter
|
||||
delattr(tensor, "_is_param")
|
||||
tensor.data = target
|
||||
tensor.requires_grad = target.requires_grad
|
||||
tensor.requires_grad = requires_grad
|
||||
# subclass of torch.Tensor does not have tolist() method
|
||||
# overwrite this method after materialization or distribution
|
||||
tensor.tolist = MethodType(_data_tolist, tensor)
|
||||
|
@ -212,9 +213,10 @@ class LazyTensor(torch.Tensor):
|
|||
Returns:
|
||||
torch.Tensor: The materialized tensor (self).
|
||||
"""
|
||||
requires_grad = self.requires_grad
|
||||
target = self._materialize_data()
|
||||
self.clean()
|
||||
return _convert_cls(self, target)
|
||||
return _convert_cls(self, target, requires_grad=requires_grad)
|
||||
|
||||
def clean(self) -> None:
|
||||
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe._operation import (
|
||||
DPGradScalerIn,
|
||||
DPGradScalerOut,
|
||||
EPGradScalerIn,
|
||||
EPGradScalerOut,
|
||||
all_to_all_uneven,
|
||||
)
|
||||
from colossalai.shardformer.layer.linear import ParallelModule
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||
|
||||
|
||||
class EpDeepseekV3MoE(ParallelModule):
|
||||
"""
|
||||
A mixed expert module containing shared experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
||||
|
||||
def setup_process_groups(
|
||||
self,
|
||||
moe_dp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
):
|
||||
assert moe_dp_group is not None
|
||||
assert ep_group is not None
|
||||
|
||||
self.ep_size = dist.get_world_size(ep_group)
|
||||
self.ep_rank = dist.get_rank(ep_group)
|
||||
self.num_experts = self.config.n_routed_experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.experts_per_rank = self.num_experts_per_ep
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
|
||||
# setup moe_dp group
|
||||
self.moe_dp_group = moe_dp_group
|
||||
self.moe_dp_size = dist.get_world_size(moe_dp_group)
|
||||
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_ep_group(p, ep_group)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module,
|
||||
moe_dp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> "EpDeepseekV3MoE":
|
||||
if module.__class__.__name__ != "DeepseekV3MLP":
|
||||
module.__class__ = EpDeepseekV3MoE
|
||||
module.setup_process_groups(moe_dp_group, ep_group)
|
||||
LazyInitContext.materialize(module)
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
if self.ep_size > 1:
|
||||
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
||||
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
|
||||
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)
|
||||
|
||||
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
|
||||
input_split_sizes = tokens_per_ep_rank.tolist()
|
||||
|
||||
gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
|
||||
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
|
||||
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
|
||||
s = 0
|
||||
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
|
||||
gatherd_idxs[s : s + k] = i % self.experts_per_rank
|
||||
s += k
|
||||
gatherd_idxs = gatherd_idxs.argsort()
|
||||
sorted_tokens = gathered_tokens[gatherd_idxs]
|
||||
tokens_per_expert = tokens_per_expert_post_gather
|
||||
|
||||
# moe-dp related code
|
||||
activate_experts = tokens_per_expert_post_gather > 0
|
||||
activate_experts = activate_experts.int()
|
||||
dist.all_reduce(activate_experts, group=self.moe_dp_group)
|
||||
|
||||
# ep related code
|
||||
sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)
|
||||
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
# moe-dp related code
|
||||
tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
# moe-dp related code
|
||||
expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
if len(outputs) > 0:
|
||||
outs = torch.cat(outputs, dim=0)
|
||||
else:
|
||||
assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
|
||||
outs = sorted_tokens
|
||||
|
||||
if self.ep_size > 1:
|
||||
outs = EPGradScalerOut.apply(outs, self.ep_size)
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[gatherd_idxs] = outs
|
||||
gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
|
||||
outs = gathered_tokens
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
(new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
|
||||
return final_out
|
||||
|
||||
|
||||
def deepseek_v3_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_key_values_length = 0
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and i > 0:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
|
@ -167,6 +167,13 @@ _POLICY_LIST = {
|
|||
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
|
||||
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
|
||||
),
|
||||
# DeepseekV3
|
||||
"transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
|
||||
file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
|
||||
),
|
||||
"transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
|
||||
file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
|
||||
),
|
||||
# Falcon
|
||||
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
|
||||
file_name="falcon", class_name="FalconModelPolicy"
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm
|
||||
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
|
||||
|
||||
|
||||
class DeepseekV3Policy(Policy):
|
||||
def config_sanity_check(self):
|
||||
assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
|
||||
assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
|
||||
assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
|
||||
|
||||
def preprocess(self):
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
||||
policy = {}
|
||||
|
||||
# support gradient checkpointing
|
||||
# policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward})
|
||||
|
||||
if self.shard_config.expert_parallel_size > 1:
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp",
|
||||
target_module=EpDeepseekV3MoE,
|
||||
kwargs={
|
||||
"ep_group": self.shard_config.ep_group,
|
||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||
},
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key="DeepseekV3DecoderLayer",
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# TODO: prevent casting to fp32
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key="DeepseekV3DecoderLayer",
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key="DeepseekV3Model",
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
class DeepseekV3ModelPolicy(DeepseekV3Policy):
|
||||
pass
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
|
||||
pass
|
|
@ -68,6 +68,10 @@ class ShardConfig:
|
|||
def sequence_parallel_size(self):
|
||||
return self._sequence_parallel_size
|
||||
|
||||
@property
|
||||
def expert_parallel_size(self):
|
||||
return self._expert_parallel_size
|
||||
|
||||
def __post_init__(self):
|
||||
# turn on all optimization if all_optimization is set to True
|
||||
if self.enable_all_optimization:
|
||||
|
@ -103,6 +107,8 @@ class ShardConfig:
|
|||
else:
|
||||
self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
|
||||
|
||||
self._expert_parallel_size = dist.get_world_size(self.ep_group) if self.ep_group else 1
|
||||
|
||||
def _turn_on_all_optimization(self):
|
||||
"""
|
||||
Turn on all optimization.
|
||||
|
|
|
@ -4,11 +4,13 @@ import resource
|
|||
import time
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from data_utils import RandomDataset
|
||||
from model_utils import format_numel_str, get_model_numel
|
||||
from peft import LoraConfig
|
||||
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
@ -29,7 +31,7 @@ warnings.filterwarnings("ignore")
|
|||
|
||||
# We have lots of llamas for your choice!
|
||||
MODEL_CONFIGS = {
|
||||
"100m": lambda: AutoConfig.from_pretrained(
|
||||
"100m": AutoConfig.from_pretrained(
|
||||
"deepseek-ai/deepseek-moe-16b-base",
|
||||
max_position_embeddings=4096,
|
||||
num_hidden_layers=1,
|
||||
|
@ -44,20 +46,29 @@ MODEL_CONFIGS = {
|
|||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"7b": lambda: AutoConfig.from_pretrained(
|
||||
"7b": AutoConfig.from_pretrained(
|
||||
"deepseek-ai/deepseek-moe-16b-base",
|
||||
max_position_embeddings=4096,
|
||||
num_hidden_layers=13,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"14b": lambda: AutoConfig.from_pretrained(
|
||||
"14b": AutoConfig.from_pretrained(
|
||||
"deepseek-ai/deepseek-moe-16b-base",
|
||||
max_position_embeddings=4096,
|
||||
num_hidden_layers=26,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"v3-6b": AutoConfig.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V3",
|
||||
num_hidden_layers=5,
|
||||
first_k_dense_replace=2,
|
||||
n_routed_experts=32,
|
||||
vocab_size=8192,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
@ -119,6 +130,7 @@ def main():
|
|||
help="Sequence parallelism mode",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
||||
parser.add_argument("--enable_lora", action="store_true", help="Enable LoRA")
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch()
|
||||
|
@ -151,7 +163,7 @@ def main():
|
|||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
enable_sequence_parallelism=args.sp > 1,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_fused_normalization=get_accelerator().is_available(),
|
||||
enable_flash_attention=args.xformers,
|
||||
microbatch_size=args.mbs,
|
||||
precision="bf16",
|
||||
|
@ -171,7 +183,10 @@ def main():
|
|||
# ==============================
|
||||
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
|
||||
|
||||
config = MODEL_CONFIGS[args.config]()
|
||||
if args.config in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
|
@ -189,11 +204,25 @@ def main():
|
|||
else nullcontext()
|
||||
)
|
||||
|
||||
attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
|
||||
with init_ctx:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
|
||||
).to(torch.bfloat16)
|
||||
if args.enable_lora:
|
||||
booster.enable_lora(
|
||||
model,
|
||||
lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
|
||||
)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if model.__class__.__name__.startswith("DeepseekV3"):
|
||||
model.eval()
|
||||
# enable grad for moe layers
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "DeepseekV3MoE":
|
||||
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
|
|
@ -7,6 +7,7 @@ from torch import Tensor
|
|||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def divide(x: float, y: float) -> float:
|
||||
|
@ -29,7 +30,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
|||
# tensor = tensor / world_size
|
||||
# return tensor.item()
|
||||
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
|
||||
tensor = torch.tensor([x], device=get_current_device(), dtype=torch.float)
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
|
|
@ -5,6 +5,7 @@ from .bloom import *
|
|||
from .chatglm2 import *
|
||||
from .command import *
|
||||
from .deepseek import *
|
||||
from .deepseek_v3 import *
|
||||
from .falcon import *
|
||||
from .gpt import *
|
||||
from .gptj import *
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
# modified from tests/kit/model_zoo/transformers/mistral.py
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ===============================
|
||||
# Register single-sentence Mixtral
|
||||
# ===============================
|
||||
|
||||
|
||||
def data_gen():
|
||||
# Generated from following code snippet
|
||||
#
|
||||
# from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
|
||||
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
|
||||
# tokenized_input = tokenizer([input], return_tensors="pt")
|
||||
# input_ids = tokenized_input['input_ids']
|
||||
# attention_mask = tokenized_input['attention_mask']
|
||||
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def data_gen_for_lm():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn = lambda x: x[0].mean()
|
||||
loss_fn_for_lm = lambda x: x.loss
|
||||
|
||||
|
||||
def init_deepseek():
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V3",
|
||||
hidden_size=128,
|
||||
intermediate_size=320,
|
||||
kv_lora_rank=4,
|
||||
moe_intermediate_size=32,
|
||||
num_attention_heads=4,
|
||||
num_experts_per_tok=4,
|
||||
n_group=4,
|
||||
num_hidden_layers=3,
|
||||
num_key_value_heads=4,
|
||||
first_k_dense_replace=1,
|
||||
q_lora_rank=8,
|
||||
torch_dtype="bfloat16",
|
||||
n_routed_experts=16,
|
||||
topk_group=2,
|
||||
v_head_dim=32,
|
||||
qk_nope_head_dim=32,
|
||||
qk_rope_head_dim=32,
|
||||
trust_remote_code=True,
|
||||
vocab_size=2048,
|
||||
)
|
||||
|
||||
if hasattr(config, "pad_token_id"):
|
||||
config.pad_token_id = config.eos_token_id
|
||||
model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
# enable grad for moe layers
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "DeepseekV3MoE":
|
||||
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
|
||||
return model
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="transformers_deepseek_v3",
|
||||
model_fn=init_deepseek,
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
|
@ -223,7 +223,6 @@ def run_forward_backward_with_hybrid_plugin(
|
|||
for k, v in data.items():
|
||||
unshard_test_data[k] = data[k].clone()
|
||||
|
||||
sharded_model.train()
|
||||
if booster.plugin.stage_manager is not None:
|
||||
for k, v in shard_test_data.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
|
@ -248,7 +247,6 @@ def run_forward_backward_with_hybrid_plugin(
|
|||
sharded_loss = criterion(sharded_output)
|
||||
sharded_optimizer.backward(sharded_loss)
|
||||
|
||||
org_model.train()
|
||||
if booster.plugin.stage_manager is not None:
|
||||
for k, v in unshard_test_data.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
|
||||
seed_all(42)
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin
|
||||
)
|
||||
if enable_gradient_checkpointing:
|
||||
# org_model.gradient_checkpointing_enable()
|
||||
sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
|
||||
org_model = org_model.to(torch.bfloat16)
|
||||
org_model.eval()
|
||||
sharded_model.eval()
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
)
|
||||
|
||||
assert_close(org_loss, sharded_loss)
|
||||
|
||||
param_dict = {n: p for n, p in org_model.named_parameters()}
|
||||
for n, p in sharded_model.unwrap().named_parameters():
|
||||
if n in param_dict:
|
||||
if booster.plugin.zero_stage == 0:
|
||||
grad = p.grad
|
||||
target_grad = param_dict[n].grad
|
||||
else:
|
||||
grad = sharded_optimizer.get_working_grad_by_param_id(id(p))
|
||||
pg = sharded_optimizer.param_to_pg[p]
|
||||
target_grad = param_dict[n].grad
|
||||
if target_grad is None:
|
||||
continue
|
||||
target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
|
||||
assert_close(grad, target_grad, atol=3e-1, rtol=0)
|
||||
|
||||
|
||||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
# zero 1
|
||||
(1, 4),
|
||||
(1, 2),
|
||||
],
|
||||
)
|
||||
def run_deepseek_v3_test(config: Tuple[int, ...]):
|
||||
zero_stage, ep_size = config
|
||||
plugin_config = dict(
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
ep_size=ep_size,
|
||||
zero_stage=zero_stage,
|
||||
overlap_communication=False,
|
||||
precision="bf16",
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
||||
check_forward_backward(
|
||||
model_fn,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
loss_fn,
|
||||
plugin_config,
|
||||
)
|
||||
|
||||
|
||||
def check_deepseek_v3(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_deepseek_v3_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_deepseek_v3(world_size):
|
||||
spawn(check_deepseek_v3, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_deepseek_v3(world_size=4)
|
Loading…
Reference in New Issue