mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support ep for deepseek v3 (#6185)
* [feature] support ep for deepseek v3
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix test
* [shardformer] fix deepseek v3 init
* [lazy] fit lora for lazy init
* [example] support npu for deepseek v3

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Branch: pull/6191/head
parent 17062c83b9
commit 2b415e5999
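For orientation, here is a minimal sketch of how the expert-parallel path added by this commit is meant to be driven from user code, assuming the usual Booster workflow; the optimizer, the tiny config overrides, and the ep_size value are illustrative stand-ins rather than values taken from this diff.

    import torch

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin
    from transformers import AutoConfig, AutoModelForCausalLM

    colossalai.launch_from_torch()

    # ep_size > 1 shards the routed experts of each MoE layer across the EP group;
    # the DeepseekV3 policy introduced in this commit performs the module replacement.
    plugin = MoeHybridParallelPlugin(pp_size=1, tp_size=1, ep_size=2, zero_stage=1, precision="bf16")
    booster = Booster(plugin=plugin)

    config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", num_hidden_layers=2, trust_remote_code=True)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    model, optimizer, *_ = booster.boost(model, optimizer)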
@@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelPlugin,
     HybridParallelZeroOptimizer,
     get_param_info,
-    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -468,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.ep_size > 1:
-                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
-                # but the optimizer is not aware of ep, so we need to update the optimizer
-                reinitialize_optimizer(optimizer, model)
-
             if self.zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
@@ -489,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
@@ -507,7 +501,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
-                    use_pipeline=self.enable_pipeline_parallelism,
+                    use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                     param_info=param_info,
                     dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,
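Taken together, the three hunks above touch every optimizer wrapper the MoE plugin can return: instead of calling reinitialize_optimizer when EP shards away some parameters, the wrappers are now created with use_pipeline enabled whenever ep_size > 1, presumably so they rely on param_info to cope with ranks that no longer own every parameter. A schematic of the selection logic visible in these hunks, as simplified pseudologic rather than the plugin's actual code:

    def pick_wrapper(zero_stage: int, precision: str, enable_pp: bool, ep_size: int) -> tuple:
        # Simplified view of the branches shown above.
        use_pipeline = enable_pp or ep_size > 1
        if zero_stage == 0:
            name = "HybridParallelAMPOptimizer" if precision in ("fp16", "bf16") else "HybridParallelNaiveOptimizer"
        else:
            name = "MoeHybridParallelZeroOptimizer"
        return name, use_pipeline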
@@ -64,7 +64,10 @@ class ProcessGroupMesh:
         system resources.
         """
         for group in self._ranks_to_group.values():
-            dist.destroy_process_group(group)
+            try:
+                dist.destroy_process_group(group)
+            except ValueError:
+                pass

         # Manually clear all process groups to save memory
         gc.collect()
@@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
     return tensor.data.tolist()


-def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.

     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.

@@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
+    requires_grad = target.requires_grad if requires_grad is None else requires_grad
     cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
     if cls_to_become is Parameter:
         # to fit UninitializedParameter
         delattr(tensor, "_is_param")
     tensor.data = target
-    tensor.requires_grad = target.requires_grad
+    tensor.requires_grad = requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
     tensor.tolist = MethodType(_data_tolist, tensor)

@@ -212,9 +213,10 @@ class LazyTensor(torch.Tensor):
         Returns:
             torch.Tensor: The materialized tensor (self).
         """
+        requires_grad = self.requires_grad
         target = self._materialize_data()
         self.clean()
-        return _convert_cls(self, target)
+        return _convert_cls(self, target, requires_grad=requires_grad)

     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
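The lazy-init hunks above capture requires_grad before materialization and thread it through _convert_cls, so a flag already set on the lazy tensor (for instance a base weight frozen for LoRA) is not clobbered by the freshly materialized data. A standalone toy illustration of the ordering concern in plain PyTorch, not the ColossalAI code path itself:

    import torch
    from torch.nn import Parameter

    param = Parameter(torch.empty(4, 4))
    param.requires_grad_(False)            # e.g. a LoRA-frozen base weight

    requires_grad = param.requires_grad    # capture the flag before the data swap
    new_data = torch.randn(4, 4)           # stand-in for materialized data (requires_grad=False by default)
    param.data = new_data
    param.requires_grad = requires_grad    # restore the captured flag rather than inheriting new_data's

    assert param.requires_grad is False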
new file
@@ -0,0 +1,277 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        if module.__class__.__name__ != "DeepseekV3MLP":
            module.__class__ = EpDeepseekV3MoE
            module.setup_process_groups(moe_dp_group, ep_group)
        LazyInitContext.materialize(module)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

        # moe-dp related code
        activate_experts = tokens_per_expert_post_gather > 0
        activate_experts = activate_experts.int()
        dist.all_reduce(activate_experts, group=self.moe_dp_group)

        # ep related code
        sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out


def deepseek_v3_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for i, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and i > 0:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
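The bookkeeping at the top of moe_forward is easiest to follow on a toy case. The sketch below replays the counting and sorting steps on a single rank (no all-to-all), with made-up routing assignments; the values are chosen only to make the grouping visible.

    import torch

    # 4 tokens, 8 experts, top-2 routing; hidden size 1 keeps the values readable.
    topk_ids = torch.tensor([[1, 3], [0, 1], [3, 7], [1, 2]])     # (num_tokens, top_k)
    x = torch.arange(4, dtype=torch.float32).unsqueeze(-1)        # token i carries value i

    cnts = topk_ids.new_zeros((topk_ids.shape[0], 8))
    cnts.scatter_(1, topk_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)            # rows each expert must process
    idxs = topk_ids.view(-1).argsort()             # routing slots ordered by destination expert
    sorted_tokens = x[idxs // topk_ids.shape[1]]   # each token duplicated once per chosen expert

    print(tokens_per_expert.tolist())              # [1, 3, 1, 2, 0, 0, 0, 1]
    print(sorted_tokens.squeeze(-1).tolist())      # token values grouped by expert id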
@@ -167,6 +167,13 @@ _POLICY_LIST = {
     "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
         file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
     ),
+    # DeepseekV3
+    "transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
+    ),
+    "transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
+    ),
     # Falcon
     "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
         file_name="falcon", class_name="FalconModelPolicy"
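The two new PolicyLocation entries let the auto-policy lookup map the remote-code DeepSeek-V3 classes to the policy module added below. Roughly how such a (file_name, class_name) registry can be resolved, shown as a hypothetical helper for illustration only; ColossalAI's real lookup lives in its auto-policy code.

    import importlib

    _POLICY_LIST = {
        "transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLMPolicy"),
    }

    def resolve_policy(qualified_class_name: str):
        file_name, class_name = _POLICY_LIST[qualified_class_name]
        module = importlib.import_module(f"colossalai.shardformer.policies.{file_name}")
        return getattr(module, class_name)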
new file
@@ -0,0 +1,83 @@
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

        policy = {}

        # support gradient checkpointing
        # policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward})

        if self.shard_config.expert_parallel_size > 1:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    pass


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    pass
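When this policy swaps a decoder layer's mlp for EpDeepseekV3MoE, the replacement goes through from_native_module, which reassigns the module's __class__ in place instead of rebuilding it, so the existing (possibly lazily initialized) weights are kept. The trick in isolation, with toy stand-in classes rather than the real modules:

    import torch.nn as nn

    class PlainMLP(nn.Module):      # stand-in for the native DeepSeek module
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    class EpMLP(PlainMLP):          # stand-in for the EP-aware wrapper
        def describe(self):
            return "same parameters, extra expert-parallel methods"

    m = PlainMLP()
    m.__class__ = EpMLP             # in-place class swap: no weight copy
    print(isinstance(m, EpMLP), m.describe())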
@@ -68,6 +68,10 @@ class ShardConfig:
     def sequence_parallel_size(self):
         return self._sequence_parallel_size

+    @property
+    def expert_parallel_size(self):
+        return self._expert_parallel_size
+
     def __post_init__(self):
         # turn on all optimization if all_optimization is set to True
         if self.enable_all_optimization:

@@ -103,6 +107,8 @@ class ShardConfig:
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

+        self._expert_parallel_size = dist.get_world_size(self.ep_group) if self.ep_group else 1
+
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
@@ -4,11 +4,13 @@ import resource
 import time
 import warnings
 from contextlib import nullcontext
+from types import MethodType

 import torch
 import torch.distributed as dist
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
+from peft import LoraConfig
 from performance_evaluator import PerformanceEvaluator, get_profile_context
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM

@@ -29,7 +31,7 @@ warnings.filterwarnings("ignore")

 # We have lots of llamas for your choice!
 MODEL_CONFIGS = {
-    "100m": lambda: AutoConfig.from_pretrained(
+    "100m": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=1,

@@ -44,20 +46,29 @@ MODEL_CONFIGS = {
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "7b": lambda: AutoConfig.from_pretrained(
+    "7b": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=13,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "14b": lambda: AutoConfig.from_pretrained(
+    "14b": AutoConfig.from_pretrained(
         "deepseek-ai/deepseek-moe-16b-base",
         max_position_embeddings=4096,
         num_hidden_layers=26,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
+    "v3-6b": AutoConfig.from_pretrained(
+        "deepseek-ai/DeepSeek-V3",
+        num_hidden_layers=5,
+        first_k_dense_replace=2,
+        n_routed_experts=32,
+        vocab_size=8192,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ),
 }

@@ -119,6 +130,7 @@ def main():
         help="Sequence parallelism mode",
     )
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+    parser.add_argument("--enable_lora", action="store_true", help="Enable LoRA")
     args = parser.parse_args()

     colossalai.launch_from_torch()

@@ -151,7 +163,7 @@ def main():
             sp_size=args.sp,
             sequence_parallelism_mode=args.sp_mode,
             enable_sequence_parallelism=args.sp > 1,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",

@@ -171,7 +183,10 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)

-    config = MODEL_CONFIGS[args.config]()
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)

     torch.cuda.manual_seed(42)

@@ -189,11 +204,25 @@ def main():
         else nullcontext()
     )

+    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
+        model = AutoModelForCausalLM.from_config(
+            config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
+        ).to(torch.bfloat16)
+    if args.enable_lora:
+        booster.enable_lora(
+            model,
+            lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
+        )

     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
+    if model.__class__.__name__.startswith("DeepseekV3"):
+        model.eval()
+        # enable grad for moe layers
+        for m in model.modules():
+            if m.__class__.__name__ == "DeepseekV3MoE":
+                m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)

     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
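Both the benchmark above and the test-kit registration further down rebind moe_infer to its __wrapped__ attribute. Assuming the upstream DeepSeek-V3 modeling code decorates moe_infer with something like torch.no_grad() (a functools.wraps-style decorator that keeps the original function under __wrapped__), the rebinding strips the decorator so autograd can track the expert computation during training. A toy reproduction of the effect:

    from types import MethodType

    import torch

    class Layer:
        @torch.no_grad()            # stands in for the decorator assumed on moe_infer
        def moe_infer(self, x):
            return x * 2

    layer = Layer()
    x = torch.ones(2, requires_grad=True)

    print(layer.moe_infer(x).requires_grad)   # False: the decorator disabled grad tracking

    layer.moe_infer = MethodType(Layer.moe_infer.__wrapped__, layer)
    print(layer.moe_infer(x).requires_grad)   # True: the original, undecorated function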
@@ -7,6 +7,7 @@ from torch import Tensor
 from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

 from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device


 def divide(x: float, y: float) -> float:

@@ -29,7 +30,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     # tensor = tensor / world_size
     # return tensor.item()

-    tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
+    tensor = torch.tensor([x], device=get_current_device(), dtype=torch.float)
     dist.all_reduce(tensor)
     tensor = tensor / world_size
     return tensor.item()
@@ -5,6 +5,7 @@ from .bloom import *
 from .chatglm2 import *
 from .command import *
 from .deepseek import *
+from .deepseek_v3 import *
 from .falcon import *
 from .gpt import *
 from .gptj import *
new file
@@ -0,0 +1,87 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
from types import MethodType

import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence Mixtral
# ===============================


def data_gen():
    # Generated from following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss


def init_deepseek():

    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3",
        hidden_size=128,
        intermediate_size=320,
        kv_lora_rank=4,
        moe_intermediate_size=32,
        num_attention_heads=4,
        num_experts_per_tok=4,
        n_group=4,
        num_hidden_layers=3,
        num_key_value_heads=4,
        first_k_dense_replace=1,
        q_lora_rank=8,
        torch_dtype="bfloat16",
        n_routed_experts=16,
        topk_group=2,
        v_head_dim=32,
        qk_nope_head_dim=32,
        qk_rope_head_dim=32,
        trust_remote_code=True,
        vocab_size=2048,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    # enable grad for moe layers
    for m in model.modules():
        if m.__class__.__name__ == "DeepseekV3MoE":
            m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
    return model


model_zoo.register(
    name="transformers_deepseek_v3",
    model_fn=init_deepseek,
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -223,7 +223,6 @@ def run_forward_backward_with_hybrid_plugin(
     for k, v in data.items():
         unshard_test_data[k] = data[k].clone()

-    sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in shard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:

@@ -248,7 +247,6 @@ def run_forward_backward_with_hybrid_plugin(
     sharded_loss = criterion(sharded_output)
     sharded_optimizer.backward(sharded_loss)

-    org_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in unshard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new file
@@ -0,0 +1,102 @@
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    seed_all(42)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    org_model = org_model.to(torch.bfloat16)
    org_model.eval()
    sharded_model.eval()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    assert_close(org_loss, sharded_loss)

    param_dict = {n: p for n, p in org_model.named_parameters()}
    for n, p in sharded_model.unwrap().named_parameters():
        if n in param_dict:
            if booster.plugin.zero_stage == 0:
                grad = p.grad
                target_grad = param_dict[n].grad
            else:
                grad = sharded_optimizer.get_working_grad_by_param_id(id(p))
                pg = sharded_optimizer.param_to_pg[p]
                target_grad = param_dict[n].grad
                if target_grad is None:
                    continue
                target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
            assert_close(grad, target_grad, atol=3e-1, rtol=0)


@parameterize(
    "config",
    [
        # zero 1
        (1, 4),
        (1, 2),
    ],
)
def run_deepseek_v3_test(config: Tuple[int, ...]):
    zero_stage, ep_size = config
    plugin_config = dict(
        pp_size=1,
        tp_size=1,
        ep_size=ep_size,
        zero_stage=zero_stage,
        overlap_communication=False,
        precision="bf16",
        find_unused_parameters=True,
    )

    sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

        check_forward_backward(
            model_fn,
            data_gen_fn,
            output_transform_fn,
            loss_fn,
            plugin_config,
        )


def check_deepseek_v3(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_v3_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek_v3(world_size):
    spawn(check_deepseek_v3, world_size)


if __name__ == "__main__":
    test_deepseek_v3(world_size=4)