mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer]fix the num_heads assert for llama model and qwen model (#5704)
* fix the num_heads assert
* fix the transformers import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix the import
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a3cc68ca93
commit 537f6a3855
colossalai/shardformer/modeling/qwen2.py

@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (
 try:
     from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
         _prepare_4d_causal_attention_mask,
         _prepare_4d_causal_attention_mask_for_sdpa,
+        apply_rotary_pos_emb,
+        repeat_kv,
     )
 except ImportError:
     Qwen2Model = "Qwen2Model"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
     Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"

 from transformers.utils import logging
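Note: a minimal, runnable sketch of the guarded-import pattern the hunk above introduces (illustration only, not part of the commit). The assumption is simply that an older transformers release may not ship the qwen2 module; the string placeholders keep the names defined so the rest of the file can still be imported.

# Illustration of the try/except ImportError fallback used above.
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention  # real class on new transformers
except ImportError:
    Qwen2Attention = "Qwen2Attention"  # placeholder on transformers versions without Qwen2

if isinstance(Qwen2Attention, str):
    print("Qwen2 is unavailable in this transformers version; placeholder name is used.")
else:
    print("Qwen2 class imported:", Qwen2Attention.__name__)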
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:


 def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
-    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
-
-    from colossalai.shardformer.layer import ColoAttention
-
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
colossalai/shardformer/policies/llama.py

@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
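For reference, a small self-contained sketch of how the strengthened LlamaPolicy check above evaluates for a grouped-query-attention config (the config class, helper name, and numbers below are hypothetical, not from the commit):

# Illustration only: the new key/value-head check with concrete numbers.
class DummyConfig:
    num_attention_heads = 32
    num_key_value_heads = 8  # GQA: fewer KV heads than attention heads


def check_heads(config, tensor_parallel_size: int) -> None:
    # Mirrors the spirit of the policy assertions: every TP rank must get a whole number of heads.
    assert (
        config.num_attention_heads % tensor_parallel_size == 0
    ), "The number of attention heads must be divisible by tensor parallel size."
    if hasattr(config, "num_key_value_heads"):
        assert (
            config.num_key_value_heads >= tensor_parallel_size
            and config.num_key_value_heads % tensor_parallel_size == 0
        ), "The number of key_value heads must be divisible by, and must not be less than tensor parallel size."


check_heads(DummyConfig(), tensor_parallel_size=4)   # passes: 32 % 4 == 0, 8 >= 4, 8 % 4 == 0
# check_heads(DummyConfig(), tensor_parallel_size=16)  # would fail: only 8 KV heads for 16 ranks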
colossalai/shardformer/policies/qwen2.py

@@ -21,6 +21,26 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
+
+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2DecoderLayer,
+        Qwen2FlashAttention2,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        Qwen2SdpaAttention,
+    )
+except ImportError:
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
+    Qwen2SdpaAttention = "Qwen2SdpaAttention"
+    Qwen2DecoderLayer = "Qwen2DecoderLayer"
+    Qwen2Model = "Qwen2Model"
+
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -45,21 +65,6 @@ class Qwen2Policy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        try:
-            from transformers.models.qwen2.modeling_qwen2 import (
-                Qwen2Attention,
-                Qwen2DecoderLayer,
-                Qwen2FlashAttention2,
-                Qwen2Model,
-                Qwen2SdpaAttention,
-            )
-        except ImportError:
-            Qwen2Attention = "Qwen2Attention"
-            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-            Qwen2SdpaAttention = "Qwen2SdpaAttention"
-            Qwen2DecoderLayer = "Qwen2DecoderLayer"
-            Qwen2Model = "Qwen2Model"
-
         ATTN_IMPLEMENTATION = {
             "eager": Qwen2Attention,
             "flash_attention_2": Qwen2FlashAttention2,
@@ -82,6 +87,13 @@ class Qwen2Policy(Policy):
             warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
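Once the assertions above pass, the policy shards attention by dividing the hidden size and head count across tensor-parallel ranks. A sketch of the per-rank values that decoder_attribute_replacement works out to for a hypothetical Qwen2-style config (numbers are illustrative, not from the commit):

# Illustration only: per-rank attribute values after the divisibility checks succeed.
hidden_size = 4096
num_attention_heads = 32
tensor_parallel_size = 4

decoder_attribute_replacement = {
    "self_attn.hidden_size": hidden_size // tensor_parallel_size,        # 1024 per rank
    "self_attn.num_heads": num_attention_heads // tensor_parallel_size,  # 8 heads per rank
}
print(decoder_attribute_replacement)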
@@ -256,7 +268,6 @@ class Qwen2Policy(Policy):
 class Qwen2ModelPolicy(Qwen2Policy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

         if self.pipeline_stage_manager:
             # set None as default
@@ -277,10 +288,7 @@ class Qwen2ModelPolicy(Qwen2Policy):

 class Qwen2ForCausalLMPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForCausalLM
-
         policy = super().module_policy()
-
         setattr(self.shard_config, "causal_lm", True)

         if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):

 class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForSequenceClassification
-
         policy = super().module_policy()
-
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {