mirror of https://github.com/InternLM/InternLM
Fit to flash attention 1.0
parent 78353e12cf
commit a35ce4c888
@@ -8,7 +8,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_varlen_kvpacked_func
+from flash_attn import flash_attn_unpadded_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
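Side note: the rename tracks the flash-attn package itself: the 1.0.x releases ship flash_attn_unpadded_kvpacked_func, while 2.x renamed it to flash_attn_varlen_kvpacked_func. A minimal sketch of a version-tolerant import, assuming both names are importable from the package top level (the try/except fallback is illustrative, not part of this commit):

# Illustrative only: resolve whichever name the installed flash-attn provides.
try:
    # flash-attn >= 2.0 uses the "varlen" name
    from flash_attn import flash_attn_varlen_kvpacked_func as kvpacked_attn_func
except ImportError:
    # flash-attn 1.0.x uses the "unpadded" name (what this commit targets)
    from flash_attn import flash_attn_unpadded_kvpacked_func as kvpacked_attn_func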
@@ -280,7 +280,7 @@ class MHA(nn.Module):
             if total_kv.dtype not in [torch.float16, torch.bfloat16]:
                 total_kv = total_kv.to(torch.bfloat16)

-            output = flash_attn_varlen_kvpacked_func(
+            output = flash_attn_unpadded_kvpacked_func(
                 total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
             ).to(x.dtype)

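For readability, a hedged sketch of what the positional arguments in this call map to, assuming the flash-attn 1.0.x signature of flash_attn_unpadded_kvpacked_func (the argument annotations and the helper name below are mine, not taken from the repo):

# Sketch, assuming flash-attn 1.0.x; needs a CUDA device and fp16/bf16 inputs to run.
from flash_attn import flash_attn_unpadded_kvpacked_func

def packed_kv_attention(total_q, total_kv, cu_seqlens, max_seqlen_q, max_seqlen_k):
    # total_q:  (total_tokens, num_heads, head_dim), sequences packed along dim 0
    # total_kv: (total_tokens, 2, num_heads, head_dim), K and V stacked on dim 1
    # cu_seqlens: (batch + 1,) int32 cumulative sequence lengths; the same tensor
    #             is reused for q and kv because both sides share the lengths here
    return flash_attn_unpadded_kvpacked_func(
        total_q,
        total_kv,
        cu_seqlens,    # cu_seqlens_q
        cu_seqlens,    # cu_seqlens_k
        max_seqlen_q,
        max_seqlen_k,
        0.0,           # dropout_p
        None,          # softmax_scale -> defaults to 1/sqrt(head_dim)
        True,          # causal
        False,         # return_attn_probs
    )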
@@ -294,18 +294,17 @@ class MHA(nn.Module):
             k = k.squeeze(2)
             v = v.squeeze(2)
             sp = k.shape
-            expansion = q.size(2) // k.size(2)
             scores = torch.einsum(
                 "blhd,bnhd->bhln",
                 q,
-                k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                k.reshape(sp[0], sp[1], q.size(2), sp[3]),
             ) / math.sqrt(q.size(-1))
             scores = scores.masked_fill(attn_mask, -65000.0)
             scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
             context = torch.einsum(
                 "bhmn,bnhd->bmhd",
                 scores,
-                v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                v.reshape(sp[0], sp[1], q.size(2), sp[3]),
             )
         else:
             context = self.inner_cross_attn(q, kv, causal=True)
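As a sanity check on the reworked non-flash fallback, a self-contained toy run of the same einsum attention (all tensors and shapes below are fabricated for illustration; in the repo they come from MHA.forward, and q and k are assumed to have equal head counts so the reshape is a no-op):

# Standalone demo of the fallback path with toy shapes (batch=2, seq=4, heads=8, dim=16).
import math
import torch
import torch.nn.functional as F

b, seqlen, heads, dim = 2, 4, 8, 16
q = torch.randn(b, seqlen, heads, dim)
k = torch.randn(b, seqlen, heads, dim)
v = torch.randn(b, seqlen, heads, dim)
sp = k.shape

# causal mask: True marks the positions to suppress
attn_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1).view(1, 1, seqlen, seqlen)

scores = torch.einsum(
    "blhd,bnhd->bhln",
    q,
    k.reshape(sp[0], sp[1], q.size(2), sp[3]),
) / math.sqrt(q.size(-1))
scores = scores.masked_fill(attn_mask, -65000.0)
scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
context = torch.einsum(
    "bhmn,bnhd->bmhd",
    scores,
    v.reshape(sp[0], sp[1], q.size(2), sp[3]),
)
print(context.shape)  # torch.Size([2, 4, 8, 16])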