mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
1.0 KiB
28 lines
1.0 KiB
import math
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
|
|
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """Reference (non-fused) causal self-attention over equal-length sequences.

    Adapted from
    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253

    Every sequence in the batch is assumed to have exactly ``seqlen`` tokens
    (no padding or variable lengths are handled).

    Args:
        xq: query tensor. It must be viewable as ``(bs, seqlen, num_head, head_dim)``
            (presumably packed as ``(bs * seqlen, num_head, head_dim)`` -- TODO
            confirm with callers).
        xk: key tensor with the same layout as ``xq``.
        xv: value tensor with the same layout as ``xq``.
        bs: batch size.
        seqlen: number of tokens in each sequence.
        num_head: number of attention heads.
        head_dim: size of each attention head.

    Returns:
        Tensor of shape ``(bs * seqlen, num_head, head_dim)``, with the dtype and
        device of ``xv``.
    """
    # (bs, seqlen, heads, dim) -> (bs, heads, seqlen, dim) so matmuls run per head.
    queries = xq.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    keys = xk.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    values = xv.view(bs, seqlen, num_head, head_dim).transpose(1, 2)

    # Additive causal mask, built on the inputs' device so CPU and any GPU work.
    # Entries are 0 where query i may attend to key j (j <= i) and a large
    # negative number elsewhere. The (1, 1, seqlen, seqlen) shape broadcasts
    # over batch and heads, so no per-(batch, head) copy is materialized.
    allowed = torch.ones(seqlen, seqlen, dtype=torch.bool, device=queries.device).tril()
    mask = torch.zeros(seqlen, seqlen, dtype=torch.float32, device=queries.device)
    mask.masked_fill_(~allowed, -100000000.0)
    mask = mask[None, None]

    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(queries, keys.transpose(2, 3)) * sm_scale

    # Do the softmax in float32 for numerical stability, then cast back to the
    # value dtype. torch.matmul requires both operands to share a dtype, so a
    # fixed float16 cast would fail for bf16/fp32 inputs.
    probs = F.softmax(scores.float() + mask, dim=-1).to(dtype=values.dtype)

    output = torch.matmul(probs, values)
    # (bs, heads, seqlen, dim) -> (bs * seqlen, heads, dim)
    return output.transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
|