mirror of https://github.com/hpcaitech/ColossalAI
fix flash attn (#5209)
parent 365671be10
commit 451e9142b8
@@ -414,7 +414,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}
 
 
-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():
 
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
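
Note: the only non-trivial line in the modeling hunk above is the mask conversion. Below is a minimal, self-contained sketch (not part of the commit) showing what that expression does to a hypothetical boolean mask in the (bsz, 1, q_len, kv_seq_len) layout the code checks for; the True-means-attend convention and all tensor values are assumptions made purely for illustration.

# Minimal sketch (not from the commit): the shape transformation performed by
# the flash_attention_mask expression in the hunk above. The boolean convention
# (True = token may be attended to) is an assumption for this illustration only.
import torch

bsz, q_len, kv_seq_len = 2, 4, 4

# Hypothetical padding pattern: the second sequence has its last key padded.
keep = torch.ones(bsz, kv_seq_len, dtype=torch.bool)
keep[1, -1] = False

# Expand to the (bsz, 1, q_len, kv_seq_len) layout the patched forward expects.
attention_mask = keep[:, None, None, :].expand(bsz, 1, q_len, kv_seq_len)

# Same expression as the diff: keep only the last query row, drop the head
# dimension, and flip the boolean sense, yielding a (bsz, kv_seq_len) mask
# whose True entries mark positions to ignore.
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()

print(flash_attention_mask.shape)  # torch.Size([2, 4])
print(flash_attention_mask[1])     # tensor([False, False, False,  True])
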
@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 
         policy = super().module_policy()
 
+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
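
Taken together, the policy-side and modeling-side hunks mean that LlamaForCausalLMPolicy tags its ShardConfig with causal_lm=True, the policy hands that config to get_llama_flash_attention_forward, and the patched forward then keeps AttnMaskType.causal instead of building a padded mask. The sketch below (not part of the commit) mirrors only that gate; SimpleNamespace and MaskType are stand-ins for colossalai's ShardConfig and AttnMaskType, used so the example runs without the library.

# Minimal sketch (not from the commit): how the new causal_lm flag decides
# which mask path the patched forward takes.
from enum import Enum, auto
from types import SimpleNamespace


class MaskType(Enum):  # stand-in for AttnMaskType
    causal = auto()
    paddedcausal = auto()


def pick_mask_type(shard_config, attention_mask):
    """Mirror of the branch in the diff: only configs without the causal_lm
    flag build a padded-causal mask when an attention mask is provided."""
    if not getattr(shard_config, "causal_lm", False) and attention_mask is not None:
        return MaskType.paddedcausal
    return MaskType.causal


causal_cfg = SimpleNamespace(causal_lm=True)  # as set by LlamaForCausalLMPolicy
other_cfg = SimpleNamespace()                 # flag absent on other policies

print(pick_mask_type(causal_cfg, attention_mask=[[1, 1, 0]]))  # MaskType.causal
print(pick_mask_type(other_cfg, attention_mask=[[1, 1, 0]]))   # MaskType.paddedcausal
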