mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] fix the num_heads assert for llama model and qwen model (#5704)

* fix the num_heads assert
* fix the transformers import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix the import

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

pull/5711/head
parent a3cc68ca93
commit 537f6a3855
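The "fix the transformers import" half of this change guards the Qwen2 imports behind try/except, so the shardformer modules still import on transformers versions that do not ship the Qwen2 architecture. Below is a minimal, self-contained sketch of that pattern; it mirrors the intent of the diff rather than ColossalAI's exact code, and the describe helper is purely illustrative.

# Guarded-import sketch: fall back to a string placeholder when the installed
# transformers build does not provide the Qwen2 classes.
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
except ImportError:
    # Module import keeps working; code that actually needs the class
    # only fails if and when it tries to use it.
    Qwen2Attention = "Qwen2Attention"


def describe(target) -> str:
    # Works for both the real class and the placeholder string (illustrative helper).
    return target if isinstance(target, str) else target.__name__


if __name__ == "__main__":
    print(describe(Qwen2Attention))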
@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (
+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+        apply_rotary_pos_emb,
+        repeat_kv,
+    )
+except ImportError:
+    Qwen2Model = "Qwen2Model"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2Attention = "Qwen2Attention"
+
 from transformers.utils import logging
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:
 def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
-    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
-
     from colossalai.shardformer.layer import ColoAttention

     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -21,6 +21,26 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )

+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2DecoderLayer,
+        Qwen2FlashAttention2,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        Qwen2SdpaAttention,
+    )
+except ImportError:
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
+    Qwen2SdpaAttention = "Qwen2SdpaAttention"
+    Qwen2DecoderLayer = "Qwen2DecoderLayer"
+    Qwen2Model = "Qwen2Model"
+
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -45,21 +65,6 @@ class Qwen2Policy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        try:
-            from transformers.models.qwen2.modeling_qwen2 import (
-                Qwen2Attention,
-                Qwen2DecoderLayer,
-                Qwen2FlashAttention2,
-                Qwen2Model,
-                Qwen2SdpaAttention,
-            )
-        except ImportError:
-            Qwen2Attention = "Qwen2Attention"
-            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-            Qwen2SdpaAttention = "Qwen2SdpaAttention"
-            Qwen2DecoderLayer = "Qwen2DecoderLayer"
-            Qwen2Model = "Qwen2Model"
-
         ATTN_IMPLEMENTATION = {
             "eager": Qwen2Attention,
             "flash_attention_2": Qwen2FlashAttention2,
@@ -82,6 +87,13 @@ class Qwen2Policy(Policy):
             warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -256,7 +268,6 @@ class Qwen2Policy(Policy):
 class Qwen2ModelPolicy(Qwen2Policy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

         if self.pipeline_stage_manager:
             # set None as default
@@ -277,10 +288,7 @@ class Qwen2ModelPolicy(Qwen2Policy):
 class Qwen2ForCausalLMPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForCausalLM
-
         policy = super().module_policy()

         setattr(self.shard_config, "causal_lm", True)

         if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
 class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForSequenceClassification
-
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {