mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix flash attention, when mask is causal, just don't unpad it (#5084)
* fix flash attn
* fix fix

Branch: pull/5099/head
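In short: before this change, any non-None attention mask forced AttnMaskType.paddedcausal, so the flash-attention path unpadded the batch even when no token was actually padded. The hunks below instead derive a boolean validity mask and only fall back to paddedcausal when some position really is padded. Below is a minimal sketch of that decision, assuming the HF-style additive 4D mask (0 where a key may be attended, a large negative value where it is masked); the enum body, choose_mask_type, and the demo masks are illustrative stand-ins, not ColossalAI's actual implementation.

from enum import Enum, auto

import torch


class AttnMaskType(Enum):
    # stand-in for the shardformer enum referenced in the diff
    causal = auto()
    padding = auto()
    paddedcausal = auto()


def choose_mask_type(attention_mask: torch.Tensor) -> AttnMaskType:
    # `attention_mask` is assumed to be an additive (bsz, 1, q_len, kv_len) mask:
    # 0.0 = attend, large negative = padded out.
    attn_mask_type = AttnMaskType.causal
    # Same expression as in the diff: the last query row covers every key position,
    # so after inversion True marks valid tokens and False marks padding.
    flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
    # Only switch to paddedcausal (which triggers unpadding) if something is padded.
    if not torch.all(flash_attention_mask):
        attn_mask_type = AttnMaskType.paddedcausal
    return attn_mask_type


if __name__ == "__main__":
    bsz, seq = 2, 4
    no_padding = torch.zeros(bsz, 1, seq, seq)                # every position valid
    padded = no_padding.clone()
    padded[1, :, :, -1] = torch.finfo(torch.float32).min      # last key of sample 1 padded
    print(choose_mask_type(no_padding))   # AttnMaskType.causal -> unpad/repad is skipped
    print(choose_mask_type(padded))       # AttnMaskType.paddedcausal

With no padding present, the mask type stays causal and the unpad/repad round trip is skipped entirely, which is the point of the commit.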
parent 75af66cd81
commit aae496631c
@@ -51,7 +51,8 @@ def get_flash_core_attention_forward():
             attn_mask_type = AttnMaskType.causal
         else:
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.hidden_size_per_partition,
@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
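The GPT2 hunk additionally handles the non-causal case, where plain padding is selected instead of paddedcausal. A hedged sketch of the resulting branch: select_gpt2_mask_type and the is_causal flag are illustrative names, and the paddedcausal line is written as a plain assignment here (the diff itself keeps a bare == comparison on that line, which leaves attn_mask_type unchanged).

from enum import Enum, auto

import torch


class AttnMaskType(Enum):
    # stand-in for the shardformer enum referenced in the diff
    causal = auto()
    padding = auto()
    paddedcausal = auto()


def select_gpt2_mask_type(attention_mask, is_causal: bool = True):
    attn_mask_type = AttnMaskType.causal if is_causal else None
    flash_attention_mask = None
    if attention_mask is not None:
        # True = valid token, False = padded, as in the diff
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        if not torch.all(flash_attention_mask):      # at least one position really is padded
            if attn_mask_type == AttnMaskType.causal:
                attn_mask_type = AttnMaskType.paddedcausal
            else:
                attn_mask_type = AttnMaskType.padding
    return attn_mask_type, flash_attention_mask


if __name__ == "__main__":
    mask = torch.zeros(1, 1, 4, 4)      # nothing padded
    print(select_gpt2_mask_type(mask))  # (AttnMaskType.causal, all-True mask): no unpadding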
@@ -465,7 +465,8 @@ def get_llama_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
@@ -581,7 +581,8 @@ def get_opt_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
@@ -106,7 +106,10 @@ def get_whisper_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-            attn_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_type = AttnMaskType.paddedcausal
+            else:
+                attn_type = AttnMaskType.causal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
 
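The last hunk touches the all_reduce_mean helper: rebinding to tensor.data before the in-place div_ presumably keeps the division out of autograd's view (for example when the reduced tensor is a loss that still requires grad). A self-contained single-process sketch of the patched helper follows; the gloo backend, address/port, and the smoke test are illustrative choices, not part of the repo.

import os

import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Added line from the diff: operate on the raw storage so the in-place
    # division below is not recorded by autograd.
    tensor = tensor.data
    tensor.div_(dist.get_world_size())
    return tensor


if __name__ == "__main__":
    # single-process demo group; address/port are arbitrary local values
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    metric = torch.tensor(4.0)
    print(all_reduce_mean(metric).item())   # 4.0: sum over one rank divided by world size 1
    dist.destroy_process_group()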