
[moe] deepseek moe sp support

colossalchat
haze188 authored 4 months ago, committed by Hongxin Liu
commit b2952a5982
  1. colossalai/shardformer/modeling/deepseek.py (299 lines changed)
  2. colossalai/shardformer/policies/deepseek.py (78 lines changed)
  3. tests/kit/model_zoo/transformers/__init__.py (1 line changed)
  4. tests/kit/model_zoo/transformers/deepseek.py (84 lines changed)
  5. tests/test_shardformer/test_model/test_shard_deepseek.py (28 lines changed)
  6. tests/test_shardformer/test_model/test_shard_deepseek_ghz.py (231 lines changed)
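For context, a hedged usage sketch (not part of this commit) of how the new DeepSeek MoE sequence-parallel path can be enabled from user code. The kwargs mirror the "Ulysses + Flash attention" test configuration added at the bottom of this diff; treat the exact MoeHybridParallelPlugin signature and the launch flow as assumptions rather than guarantees of this change.

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Assumes colossalai.launch(...) has already initialized the distributed environment,
# as the spawned tests at the bottom of this diff do before building the plugin.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    ep_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion and dataloader are then wrapped via booster.boost(...)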

colossalai/shardformer/modeling/deepseek.py (299 lines changed)

@@ -1,12 +1,18 @@
from typing import List, Optional
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
@@ -18,6 +24,11 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -362,7 +373,14 @@ class DeepseekPipelineForwards:
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# always return a dict for intermediate stages
return {
"hidden_states": hidden_states,
@@ -479,3 +497,276 @@ class DeepseekPipelineForwards:
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
return out
def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
# DeepseekFlashAttention2 attention does not support output_attentions
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
output_attentions = False
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when sequence parallelism is enabled
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
)
print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms to float32 for training stability reasons,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32. (DeepseekRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
elif torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to have been silently cast to float32; this might be related to"
f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast the input back to"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
# sp: all-to-all communication when sequence parallelism is enabled
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # e.g. (1, 8, 128): full sequence, heads sharded across sp ranks
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # e.g. (1, 4, 256): local sequence shard, all heads
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return forward
def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
)
use_cache = False
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
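As a shape-level illustration of the all_to_all sequence-parallel path above: before attention, all_to_all_comm gathers the sequence dimension while scattering the head dimension, and after attention the inverse call restores the local sequence shard. The sketch below is a single-process stand-in (the helper fake_all_to_all is hypothetical and only reproduces the shape bookkeeping, not the inter-rank exchange), using the (1, 8, 128) / (1, 4, 256) shapes noted in the comments above.

import torch

def fake_all_to_all(x: torch.Tensor, sp_size: int, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Split x into sp_size chunks along scatter_dim and concatenate them along gather_dim.
    # The real all_to_all_comm also exchanges chunks between ranks; here we only mimic the shapes.
    chunks = x.chunk(sp_size, dim=scatter_dim)
    return torch.cat(chunks, dim=gather_dim)

sp_size = 2
q_local = torch.randn(1, 4, 256)  # (bsz, seq_len / sp_size, num_heads * head_dim) on one rank
q_full = fake_all_to_all(q_local, sp_size, scatter_dim=2, gather_dim=1)
assert q_full.shape == (1, 8, 128)  # full sequence, heads sharded across sp ranks
out_local = fake_all_to_all(q_full, sp_size, scatter_dim=1, gather_dim=2)
assert out_local.shape == (1, 4, 256)  # back to the local sequence shard before o_proj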

colossalai/shardformer/policies/deepseek.py (78 lines changed)

@@ -7,8 +7,14 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.modeling.deepseek import (
DeepseekPipelineForwards,
EPDeepseekMoE,
get_deepseek_flash_attention_forward,
get_deepseek_flash_attention_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@@ -19,6 +25,13 @@ class DeepseekPolicy(Policy):
pass
def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
"""
Because transformers library's bug for AutoModel/AutoConfig, who pop “attn_implement” twice from modeling_utils.py and configuration_utils.py.
This bug causes attn_cls to be set to sdpa. Here we assign it to "flash_attention_2".
"""
# self.origin_attn_implement = "flash_attention_2"
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -31,17 +44,61 @@ class DeepseekPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
ATTN_IMPLEMENTATION = {
"eager": "DeepseekAttention",
"flash_attention_2": "DeepseekFlashAttention2",
"sdpa": "DeepseekSdpaAttention",
}
policy = {}
print(f"{self.origin_attn_implement=}")
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "all_to_all":
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
raise NotImplementedError(
"Deepseek doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
self.append_or_create_method_replacement(
description={
"forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
"forward": get_deepseek_flash_attention_model_forward(
self.shard_config,
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key="DeepseekModel",
)
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -78,6 +135,16 @@ class DeepseekPolicy(Policy):
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key="DeepseekModel",
)
if self.shard_config.ep_group:
# expert parallel
@@ -105,10 +172,12 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -119,6 +188,7 @@ class DeepseekPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key="DeepseekModel",

tests/kit/model_zoo/transformers/__init__.py (1 line changed)

@@ -4,6 +4,7 @@ from .blip2 import *
from .bloom import *
from .chatglm2 import *
from .command import *
from .deepseek import *
from .falcon import *
from .gpt import *
from .gptj import *

tests/kit/model_zoo/transformers/deepseek.py (84 lines changed)

@@ -0,0 +1,84 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence DeepSeek
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` for LM are the output tokens; since there is no padding, reuse `input_ids` as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
def init_deepseek():
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
hidden_size=32,
intermediate_size=32,
moe_intermediate_size=32,
num_hidden_layers=2,
num_attention_heads=8,
num_key_value_heads=8,
# vocab_size=2200,
first_k_dense_replace=1,
attn_implementation="flash_attention_2",
torch_dtype="float16",
n_routed_experts=8,
trust_remote_code=True,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
print(config)
model = transformers.AutoModel.from_config(config, trust_remote_code=True)
return model
model_zoo.register(
name="transformers_deepseek",
model_fn=init_deepseek,
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mixtral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
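For reference, a hedged sketch of how this registry entry is consumed (the sharded test further down uses the same get_sub_registry call). It assumes the repository root is on PYTHONPATH and that a CUDA device with flash-attn is available, since the registered config requests flash_attention_2.

import torch

from tests.kit.model_zoo import model_zoo

sub_registry = model_zoo.get_sub_registry("transformers_deepseek")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
    model = model_fn().cuda().to(torch.float16)  # FlashAttention-2 expects fp16/bf16 activations
    data = {k: v.cuda() for k, v in data_gen_fn().items()}
    output = output_transform_fn(model(**data))
    loss = loss_fn(output)  # mean of the last hidden state
    loss.backward()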

tests/test_shardformer/test_model/test_shard_deepseek.py (28 lines changed)

@@ -36,8 +36,8 @@ CHECKED_CONFIG = [ # FOR_WORLD=8
[
# (2, 1, 2, 1, 1), # TODO debug deepseek pp
# (2, 1, 2, 2, 1), # TODO debug deepseek pp
(2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2), # TODO support deepseek sp
# (2, 1, 1, 2, 1),
(2, 1, 1, 1, 2),
# (2, 1, 4, 1, 1), # TODO debug deepseek pp
# (4, 1, 2, 1, 1), # TODO debug deepseek pp
],
@@ -69,14 +69,22 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
config.num_hidden_layers = 2
config.num_attention_heads = NUM_HEADS
config.num_key_value_heads = NUM_HEADS
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
# config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
first_k_dense_replace=1,
attn_implementation="flash_attention_2",
torch_dtype="float16",
n_routed_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
trust_remote_code=True,
)
# init model with the same seed
seed_all(10086)

tests/test_shardformer/test_model/test_shard_deepseek_ghz.py (231 lines changed)

@@ -0,0 +1,231 @@
# modified from test_shard_mistral.py
import os
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
)
org_model = org_model.to(torch.float16)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
# unwrap model
mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
# Check gradients between the original model and the sharded (ZeRO) optimizer when all_to_all sequence parallelism is enabled
if (
# booster.plugin.zero_stage in [1, 2]
booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
rank = dist.get_rank()
for n, p in shard_mixtral_model.named_parameters():
zero_grad = sharded_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 5e-5, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
mixtral_model,
shard_mixtral_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check grads
check_all_grad_tensors(grads_to_check)
for n, p in shard_mixtral_model.named_parameters():
assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
for n, p in shard_mixtral_model.named_parameters():
assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
try:
check_weight(
mixtral_model,
shard_mixtral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
except Exception as e:
rank = dist.get_rank()
print(f"{rank=}, Failed config: {test_config}")
raise e
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
# {
# "tp_size": 1,
# "pp_size": 1,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp16",
# }, # [dp(4)] + [moe_dp(4)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [moe_pp(2)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "ep_size": 2,
# "num_microbatches": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp16",
# "initial_scale": 1,
# "find_unused_parameters": True,
# }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
{ # Ulysses + Flash attention
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"ep_size": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp16",
"initial_scale": 1,
"find_unused_parameters": True,
},
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 2,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(4)] + [ep(2) + moe_tp(2)]
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "overlap_communication": False,
# "zero_stage": 0,
# "precision": "fp32"
# }, # full dp for non-moe and full ep for moe
],
)
def run_deepseek_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_deepseek(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_deepseek_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_deepseek():
spawn(check_deepseek, 4)
if __name__ == "__main__":
test_deepseek()