[Shardformer] fix the num_heads assert for llama model and qwen model (#5704)

* fix the num_heads assert

* fix the transformers import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the import

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5711/head
Wang Binluo 2024-05-10 15:33:39 +08:00 committed by GitHub
parent a3cc68ca93
commit 537f6a3855
3 changed files with 37 additions and 30 deletions

colossalai/shardformer/modeling/qwen2.py

@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (
try:
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except ImportError:
    Qwen2Model = "Qwen2Model"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
    Qwen2ForCausalLM = "Qwen2ForCausalLM"
    Qwen2Attention = "Qwen2Attention"
from transformers.utils import logging
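
The guarded import above is the standard optional-dependency pattern: if the installed transformers release does not ship the Qwen2 classes, the names fall back to sentinel strings so this module still imports cleanly. A minimal sketch of the same idea, assuming only a possibly-missing transformers class; the require_qwen2 helper is illustrative and not part of this patch:

```python
# Sketch of the guarded-import pattern used above (illustrative only).
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
except ImportError:
    # The sentinel string keeps this module importable on transformers
    # versions without Qwen2; code that really needs the class checks first.
    Qwen2Model = "Qwen2Model"


def require_qwen2():
    """Hypothetical guard: fail loudly only when Qwen2 support is actually used."""
    if isinstance(Qwen2Model, str):
        raise ImportError("Qwen2 is unavailable; please upgrade transformers.")
    return Qwen2Model
```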
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:
def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
    from colossalai.shardformer.layer import ColoAttention

    def forward(
        self: Qwen2Attention,
        hidden_states: torch.Tensor,

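get_qwen2_flash_attention_forward is a factory: it captures the shard_config in a closure and returns a replacement forward that the policy can later patch onto the attention module. A generic, hedged sketch of that pattern follows; MyAttention, get_custom_forward, and scale are made-up stand-ins, not ColossalAI or transformers API:

```python
# Generic sketch of "build a replacement forward, then patch it onto a module".
import types

import torch
import torch.nn as nn


class MyAttention(nn.Module):  # stand-in for an attention class such as Qwen2Attention
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


def get_custom_forward(scale: float):
    # The closure captures configuration (here `scale`; in the patch, shard_config).
    def forward(self: MyAttention, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states) * scale

    return forward


attn = MyAttention(16)
attn.forward = types.MethodType(get_custom_forward(0.5), attn)  # patch this instance
out = attn.forward(torch.randn(2, 16))
```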
colossalai/shardformer/policies/llama.py

@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads must be divisible by tensor parallel size."
            assert (
                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of key_value heads must be divisible by tensor parallel size."
            if hasattr(self.model.config, "num_key_value_heads"):
                assert (
                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
                ), f"The number of key_value heads must be divisible by, and must not be less than, tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

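The Llama change does two things: it only checks num_key_value_heads when the config actually defines it, and it additionally requires at least one key/value head per tensor-parallel rank. This matters for grouped-query models, whose KV head count is much smaller than their attention head count, because the policy then divides both counts by the tensor-parallel size. A standalone sketch under those assumptions; the helper name and the example numbers are illustrative, not taken from any real config:

```python
from typing import Dict, Optional


def split_heads_for_tp(
    num_attention_heads: int,
    tensor_parallel_size: int,
    num_key_value_heads: Optional[int] = None,
) -> Dict[str, int]:
    # Mirrors the assertions the patch adds, outside any policy class.
    assert num_attention_heads % tensor_parallel_size == 0, (
        "The number of attention heads must be divisible by tensor parallel size."
    )
    per_rank = {"num_heads": num_attention_heads // tensor_parallel_size}
    if num_key_value_heads is not None:
        # Grouped-query attention: every tensor-parallel rank needs at least one KV head.
        assert (
            num_key_value_heads >= tensor_parallel_size
            and num_key_value_heads % tensor_parallel_size == 0
        ), "The number of key_value heads must be divisible by, and must not be less than, tensor parallel size."
        per_rank["num_key_value_heads"] = num_key_value_heads // tensor_parallel_size
    return per_rank


print(split_heads_for_tp(64, 8, num_key_value_heads=8))   # {'num_heads': 8, 'num_key_value_heads': 1}
# split_heads_for_tp(64, 16, num_key_value_heads=8)       # would raise: only 8 KV heads for 16 ranks
```

Before this fix, the check read num_key_value_heads unconditionally, so a config without that field would presumably raise an AttributeError instead of a clear assertion.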
colossalai/shardformer/policies/qwen2.py

@@ -21,6 +21,26 @@ from ..modeling.qwen2 import (
    get_qwen2_flash_attention_forward,
    get_qwen2_model_forward_for_flash_attn,
)

try:
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2DecoderLayer,
        Qwen2FlashAttention2,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        Qwen2SdpaAttention,
    )
except ImportError:
    Qwen2ForCausalLM = "Qwen2ForCausalLM"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
    Qwen2Attention = "Qwen2Attention"
    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
    Qwen2SdpaAttention = "Qwen2SdpaAttention"
    Qwen2DecoderLayer = "Qwen2DecoderLayer"
    Qwen2Model = "Qwen2Model"

from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -45,21 +65,6 @@ class Qwen2Policy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        try:
            from transformers.models.qwen2.modeling_qwen2 import (
                Qwen2Attention,
                Qwen2DecoderLayer,
                Qwen2FlashAttention2,
                Qwen2Model,
                Qwen2SdpaAttention,
            )
        except ImportError:
            Qwen2Attention = "Qwen2Attention"
            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
            Qwen2SdpaAttention = "Qwen2SdpaAttention"
            Qwen2DecoderLayer = "Qwen2DecoderLayer"
            Qwen2Model = "Qwen2Model"

        ATTN_IMPLEMENTATION = {
            "eager": Qwen2Attention,
            "flash_attention_2": Qwen2FlashAttention2,
@@ -82,6 +87,13 @@ class Qwen2Policy(Policy):
            warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads must be divisible by tensor parallel size."
            if hasattr(self.model.config, "num_key_value_heads"):
                assert (
                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
                ), f"The number of key_value heads must be divisible by tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -256,7 +268,6 @@ class Qwen2Policy(Policy):
class Qwen2ModelPolicy(Qwen2Policy):
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

        if self.pipeline_stage_manager:
            # set None as default
@@ -277,10 +288,7 @@ class Qwen2ModelPolicy(Qwen2Policy):
class Qwen2ForCausalLMPolicy(Qwen2Policy):
    def module_policy(self):
        from transformers import Qwen2ForCausalLM

        policy = super().module_policy()
        setattr(self.shard_config, "causal_lm", True)

        if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
    def module_policy(self):
        from transformers import Qwen2ForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {