fix flash attn (#5209)

pull/4976/merge
flybird11111 2024-01-03 14:39:53 +08:00 committed by GitHub
parent 365671be10
commit 451e9142b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -414,7 +414,7 @@ class LlamaPipelineForwards:
return {"hidden_states": hidden_states}
def get_llama_flash_attention_forward():
def get_llama_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask != None:
if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(

View File

@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(),
"forward": get_llama_flash_attention_forward(self.shard_config),
},
policy=policy,
target_key=LlamaAttention,
@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
policy = super().module_policy()
setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {