mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
1.0 KiB
27 lines
1.0 KiB
import math |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
|
|
|
|
def torch_context_attention(
    xq: torch.Tensor,
    xk: torch.Tensor,
    xv: torch.Tensor,
    bs: int,
    seqlen: int,
    num_head: int,
    head_dim: int,
) -> torch.Tensor:
    """Compute causal (lower-triangular) self-attention with plain PyTorch ops.

    Adapted from
    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253

    Args:
        xq: Query tensor viewable as ``(bs, seqlen, num_head, head_dim)``.
        xk: Key tensor viewable as ``(bs, seqlen, num_head, head_dim)``.
        xv: Value tensor viewable as ``(bs, seqlen, num_head, head_dim)``.
        bs: Batch size.
        seqlen: Sequence length shared by every sequence in the batch.
        num_head: Number of attention heads.
        head_dim: Size of each attention head.

    Returns:
        Tensor of shape ``(bs * seqlen, num_head, head_dim)`` holding the
        attention output, in the dtype of ``xv``.
    """
    # (bs, seqlen, num_head, head_dim) -> (bs, num_head, seqlen, head_dim)
    queries = xq.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    keys = xk.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    values = xv.view(bs, seqlen, num_head, head_dim).transpose(1, 2)

    # Additive causal mask created directly on the inputs' device, so the
    # function works on CPU and on any CUDA device instead of requiring
    # `.cuda()`. Kept positions hold 1.0 (a uniform shift that softmax
    # ignores) and future positions a large negative value.
    mask = torch.tril(torch.ones(seqlen, seqlen, device=xq.device), diagonal=0)
    mask = mask.masked_fill(mask == 0.0, -100000000.0)

    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(queries, keys.transpose(2, 3)) * sm_scale

    # The (seqlen, seqlen) mask broadcasts over batch and head dimensions, so
    # no explicit repeat is needed. Softmax is evaluated in float32 for
    # stability, then cast back to the value dtype so the matmul below does
    # not hit a dtype mismatch for non-fp16 inputs.
    probs = F.softmax(scores.float() + mask, dim=-1).to(dtype=values.dtype)

    output = torch.matmul(probs, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output
|
|
|