
[shardformer] fix flash attention: when the mask is causal, just don't unpad it (#5084)

* fix flash attn

* fix

fix
Author: flybird11111 (committed by GitHub), branch pull/5099/head
Commit: aae496631c
6 changed files:

  1. colossalai/shardformer/modeling/chatglm2.py (3 changed lines)
  2. colossalai/shardformer/modeling/gpt2.py (9 changed lines)
  3. colossalai/shardformer/modeling/llama.py (3 changed lines)
  4. colossalai/shardformer/modeling/opt.py (3 changed lines)
  5. colossalai/shardformer/modeling/whisper.py (5 changed lines)
  6. examples/language/llama2/pretrain.py (1 changed line)
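
Every shardformer change below follows the same pattern: derive a per-token boolean mask from the last row of the expanded attention mask, and only fall back to the padded-causal path when that mask actually contains padded positions; a purely causal mask skips the unpad/repad round trip. A minimal, self-contained sketch of that check (the helper name and the toy mask below are illustrative, not ColossalAI code):

import torch

def mask_has_padding(attention_mask: torch.Tensor) -> bool:
    # attention_mask: expanded additive mask of shape (bsz, 1, q_len, kv_len),
    # 0 where a key may be attended and a large negative value where it is padded.
    # The last query row of a causal mask can see every key, so it only contains
    # non-zero entries for genuinely padded positions.
    last_row = attention_mask[:, :, -1].squeeze(1)                  # (bsz, kv_len)
    flash_attention_mask = ~(last_row.to(torch.bool)).contiguous()  # True = valid token
    return not torch.all(flash_attention_mask).item()               # any padding present?

# Toy batch of 2 sequences padded to length 4; the second one has a padded last token.
mask = torch.zeros(2, 1, 4, 4)
mask[1, :, :, -1] = torch.finfo(torch.float32).min
print(mask_has_padding(mask))                     # True  -> AttnMaskType.paddedcausal
print(mask_has_padding(torch.zeros(2, 1, 4, 4)))  # False -> keep AttnMaskType.causal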

colossalai/shardformer/modeling/chatglm2.py

@@ -51,7 +51,8 @@ def get_flash_core_attention_forward():
             attn_mask_type = AttnMaskType.causal
         else:
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.hidden_size_per_partition,
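
For context, when padding is present the flash-attention path first packs the batch into one variable-length sequence; this is the work the commit now skips whenever the mask is purely causal. A rough sketch of such an unpad step, modelled on the common unpad_input pattern used with flash-attention kernels (the function below is illustrative, not ColossalAI's implementation):

import torch

def unpad(hidden_states: torch.Tensor, valid_mask: torch.Tensor):
    # Pack (bsz, seq_len, hidden) into (total_valid, hidden) plus cu_seqlens,
    # the cumulative sequence lengths a varlen flash-attention kernel expects.
    bsz, seq_len, hidden = hidden_states.shape
    seqlens = valid_mask.sum(dim=-1, dtype=torch.int32)                   # valid tokens per sample
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    packed = hidden_states.reshape(bsz * seq_len, hidden).index_select(0, indices)
    return packed, indices, cu_seqlens

x = torch.randn(2, 4, 8)
valid = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=torch.bool)
packed, indices, cu_seqlens = unpad(x, valid)
print(packed.shape, cu_seqlens)   # torch.Size([7, 8]) tensor([0, 4, 7], dtype=torch.int32)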

colossalai/shardformer/modeling/gpt2.py

@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
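
The GPT-2 hunk is the only one that also has to handle a non-causal base mask type (for example cross-attention), so when padding is detected it chooses between paddedcausal and plain padding. A condensed, self-contained sketch of that selection (the local AttnMaskType enum only stands in for the member names visible in the diff). Note that the paddedcausal line in the diff uses ==, a comparison carried over from the earlier code, so the sketch returns the intended value explicitly:

from enum import Enum

import torch

class AttnMaskType(Enum):      # stand-in with the member names used in the diff
    padding = 1
    causal = 2
    paddedcausal = 3

def select_mask_type(base_type: AttnMaskType, flash_attention_mask: torch.Tensor) -> AttnMaskType:
    # Only switch to a padded variant when some key positions are actually padded.
    if torch.all(flash_attention_mask):
        return base_type                   # purely causal (or fully dense) mask: cheap path
    if base_type == AttnMaskType.causal:
        return AttnMaskType.paddedcausal   # causal self-attention over a padded batch
    return AttnMaskType.padding            # non-causal attention over a padded batch

padded = torch.tensor([[True, True, True], [True, True, False]])
print(select_mask_type(AttnMaskType.causal, padded))                              # paddedcausal
print(select_mask_type(AttnMaskType.causal, torch.ones(2, 3, dtype=torch.bool)))  # causal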

colossalai/shardformer/modeling/llama.py

@@ -465,7 +465,8 @@ def get_llama_flash_attention_forward():
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
        attn_output = attention(

colossalai/shardformer/modeling/opt.py

@@ -581,7 +581,8 @@ def get_opt_flash_attention_forward():
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
        attention = ColoAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling

colossalai/shardformer/modeling/whisper.py

@@ -106,7 +106,10 @@ def get_whisper_flash_attention_forward():
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-            attn_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_type = AttnMaskType.paddedcausal
+            else:
+                attn_type = AttnMaskType.causal
 
        attention = ColoAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling

examples/language/llama2/pretrain.py

@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
 
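
The one-line change in the pretraining example reassigns tensor = tensor.data before the in-place division, so the divide-by-world-size is applied to the underlying data outside autograd tracking (presumably to avoid autograd complaints when the reduced tensor is a loss that still requires grad). A self-contained sketch of the resulting helper with an illustrative usage; the training-loop names in the comments are assumptions, not taken from the example script:

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Sum the tensor across all data-parallel ranks, then divide by the world size.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor = tensor.data                 # detach from autograd before the in-place op
    tensor.div_(dist.get_world_size())
    return tensor

# Illustrative usage inside a training step (a process group is assumed to be
# initialized already, e.g. via torch.distributed.init_process_group):
#     loss = outputs.loss
#     avg_loss = all_reduce_mean(loss)   # mean loss across ranks, for logging
#     if dist.get_rank() == 0:
#         print(f"step loss: {avg_loss.item():.4f}")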
