ColossalAI/colossalai/kernel/triton/token_attention_kernel.py

# Adapted from ModelTC https://github.com/ModelTC/lightllm
import math

import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:

    @triton.jit
    def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
                             attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
                             q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
                             attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
        # Stage 1: scaled q.k logits for one (batch, head, key block) program.
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)
        start_n = tl.program_id(2)

        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        # Each kv_cache_loc row is right-aligned: valid slots fill the last seq_len columns.
        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_end_index = max_kv_cache_len

        off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride

        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

        # block_mask is 1 when this block holds valid tokens and 0 otherwise,
        # so the loop below runs at most once and skips empty blocks.
        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)

        for start_mark in range(0, block_mask, 1):
            q = tl.load(Q + off_q + start_mark)
            offs_n_new = current_batch_start_index + offs_n
            k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
                            mask=offs_n_new < current_batch_end_index,
                            other=0)
            off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
            k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
            att_value = tl.sum(q[None, :] * k, 1)
            att_value *= sm_scale
            off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
            tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    @triton.jit
    def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
                                   max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
                                   q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
                                   k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
                                   BLOCK_N: tl.constexpr):
        # Same as _token_attn_1_kernel, plus a per-head ALiBi distance penalty on the logits.
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)
        start_n = tl.program_id(2)

        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        # Each kv_cache_loc row is right-aligned: valid slots fill the last seq_len columns.
        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_end_index = max_kv_cache_len

        off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride

        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

        # block_mask is 1 when this block holds valid tokens and 0 otherwise,
        # so the loop below runs at most once and skips empty blocks.
        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)

        for start_mark in range(0, block_mask, 1):
            alibi_m = tl.load(alibi + current_head)
            q = tl.load(Q + off_q + start_mark)
            offs_n_new = current_batch_start_index + offs_n
            k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
                            mask=offs_n_new < current_batch_end_index,
                            other=0)
            off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
            k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
            att_value = tl.sum(q[None, :] * k, 1)
            att_value *= sm_scale
            # Penalize each key by its distance from the newest token (position seq_len - 1).
            att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
            off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
            tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    @torch.no_grad()
    def token_attn_fwd_1(q,
                         k,
                         attn_out,
                         kv_cache_loc,
                         kv_cache_start_loc,
                         kv_cache_seqlen,
                         max_kv_cache_len,
                         alibi=None):
        BLOCK = 32
        # shape constraints
        q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
        assert q_head_dim == k_head_dim
        assert k_head_dim in {16, 32, 64, 128}
        sm_scale = 1.0 / (k_head_dim**0.5)

        batch, head_num = kv_cache_loc.shape[0], q.shape[1]

        # One program per (batch, head, block of BLOCK key positions).
        grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))

        # num_warps is pinned to 2. The original also computed a head-dim heuristic
        # (4 if k_head_dim <= 64 else 8) on the previous line, but overwrote it immediately.
        num_warps = 2

        if alibi is not None:
            _token_attn_1_alibi_kernel[grid](
                q,
                k,
                sm_scale,
                alibi,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seqlen,
                max_kv_cache_len,
                attn_out,
                kv_cache_loc.stride(0),
                kv_cache_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                attn_out.stride(0),
                attn_out.stride(1),
                HEAD_DIM=k_head_dim,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
        else:
            _token_attn_1_kernel[grid](
                q,
                k,
                sm_scale,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seqlen,
                max_kv_cache_len,
                attn_out,
                kv_cache_loc.stride(0),
                kv_cache_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                attn_out.stride(0),
                attn_out.stride(1),
                HEAD_DIM=k_head_dim,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
        return

    @triton.jit
    def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
                                logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
                                BLOCK_SIZE: tl.constexpr):
        # Stage 2: softmax over the logits of one (batch, head) row.
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)

        col_offsets = tl.arange(0, BLOCK_SIZE)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        # Positions past seq_len load as -inf, so they contribute exp(-inf) = 0 to the sum.
        row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
                      (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
                      mask=col_offsets < current_batch_seq_len,
                      other=-float('inf')).to(tl.float32)

        # Subtract the row maximum before exponentiating for numerical stability.
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator

        tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
                 (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
                 softmax_output,
                 mask=col_offsets < current_batch_seq_len)
        return

    @torch.no_grad()
    def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
        BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
        batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]

        num_warps = 4
        if BLOCK_SIZE >= 2048:
            num_warps = 8
        if BLOCK_SIZE >= 4096:
            num_warps = 16

        _token_attn_softmax_fwd[(batch, head_num)](
            softmax_logics,
            kv_cache_start_loc,
            kv_cache_seqlen,
            softmax_prob_out,
            softmax_logics.stride(0),
            softmax_logics.stride(1),
            softmax_prob_out.stride(0),
            softmax_prob_out.stride(1),
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return

    @triton.jit
    def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
                             kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
                             v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
                             attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
                             BLOCK_N: tl.constexpr):
        # Stage 3: probability-weighted sum of V for one (batch, head) program.
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)

        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_end_index = current_batch_seq_len
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
        p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
        v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride

        acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
        for start_n in range(0, current_batch_seq_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # Advance Prob by its own token stride. The original used kv_cache_loc_s_stride here,
            # which only gives the same address when both strides are equal (e.g. both 1).
            p_value = tl.load(Prob + p_offs + start_n * prob_batch_stride,
                              mask=(start_n + offs_n) < current_batch_seq_len,
                              other=0.0)
            # v_loc holds integer cache slots, so the masked fill value is an integer 0.
            v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
                            mask=(start_n + offs_n) < current_batch_seq_len,
                            other=0)
            v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
                              mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
                              other=0.0)
            acc += tl.sum(p_value[:, None] * v_value, 0)

        acc = acc.to(tl.float16)
        off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
        out_ptrs = attn_out + off_o
        tl.store(out_ptrs, acc)
        return

    @torch.no_grad()
    def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len):
        # Compare (major, minor) numerically; comparing version strings is lexicographic
        # and would, for example, rank "2.10.0" below "2.2.0".
        triton_version = tuple(int(part) for part in triton.__version__.split(".")[:2])
        if triton_version >= (2, 1):
            BLOCK = 128
        else:
            BLOCK = 64
        batch, head = kv_cache_loc.shape[0], v.shape[1]
        grid = (batch, head)
        num_warps = 4
        dim = v.shape[-1]

        _token_attn_2_kernel[grid](
            prob,
            v,
            attn_out,
            kv_cache_loc,
            kv_cache_start_loc,
            kv_cache_seqlen,
            max_kv_cache_len,
            kv_cache_loc.stride(0),
            kv_cache_loc.stride(1),
            prob.stride(0),
            prob.stride(1),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            attn_out.stride(0),
            attn_out.stride(1),
            attn_out.stride(2),
            HEAD_DIM=dim,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    @torch.no_grad()
    def token_attention_fwd(q,
                            k,
                            v,
                            attn_out,
                            kv_cache_loc,
                            kv_cache_start_loc,
                            kv_cache_seq_len,
                            max_len_in_batch,
                            alibi=None):
        # Decode-step attention in three launches: q.k logits, softmax, then prob.v.
        head_num = k.shape[1]
        batch_size = kv_cache_seq_len.shape[0]
        calcu_shape1 = (batch_size, head_num, k.shape[2])
        total_token_num = k.shape[0]

        # Allocate the logits on q's device instead of hard-coding "cuda".
        att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device=q.device)

        token_attn_fwd_1(q.view(calcu_shape1),
                         k,
                         att_m_tensor,
                         kv_cache_loc,
                         kv_cache_start_loc,
                         kv_cache_seq_len,
                         max_len_in_batch,
                         alibi=alibi)

        prob = torch.empty_like(att_m_tensor)

        token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
        # Drop the reference so the logits buffer can be freed before stage 3.
        att_m_tensor = None
        token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
                         max_len_in_batch)

        prob = None

        return
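
For checking the kernels on tiny inputs, the three stages above (scaled q.k logits, softmax, probability-weighted sum of V) can be mirrored in plain Python. This is only a sketch under stated assumptions, not part of ColossalAI or lightllm: `reference_token_attention` and its nested-list layout are hypothetical, every sequence is assumed to hold at least one token, and the right-aligned `kv_cache_loc` rows and the ALiBi penalty follow what the kernels appear to do.

```python
import math


def reference_token_attention(q, k_cache, v_cache, kv_cache_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None):
    """Pure-Python reference for single-token (decode) attention.

    Hypothetical layout, chosen for illustration:
      q[b][h]            -> list of head_dim floats (one query token per sequence)
      k_cache[slot][h]   -> list of head_dim floats, same for v_cache
      kv_cache_loc[b]    -> list of max_kv_cache_len slot ids, valid ones right-aligned
      kv_cache_seqlen[b] -> number of valid tokens (assumed >= 1)
      alibi[h]           -> optional per-head ALiBi slope
    """
    out = []
    for b, seq_len in enumerate(kv_cache_seqlen):
        # Valid slots are the last seq_len entries of the row.
        slots = kv_cache_loc[b][max_kv_cache_len - seq_len:max_kv_cache_len]
        heads_out = []
        for h, q_vec in enumerate(q[b]):
            head_dim = len(q_vec)
            scale = 1.0 / math.sqrt(head_dim)

            # Stage 1: scaled dot-product logits, with an optional ALiBi distance penalty.
            logits = []
            for n, slot in enumerate(slots):
                logit = scale * sum(qi * ki for qi, ki in zip(q_vec, k_cache[slot][h]))
                if alibi is not None:
                    logit -= alibi[h] * (seq_len - 1 - n)
                logits.append(logit)

            # Stage 2: numerically stable softmax.
            row_max = max(logits)
            exps = [math.exp(x - row_max) for x in logits]
            denom = sum(exps)
            probs = [e / denom for e in exps]

            # Stage 3: probability-weighted sum of the value vectors.
            acc = [0.0] * head_dim
            for p, slot in zip(probs, slots):
                for d in range(head_dim):
                    acc[d] += p * v_cache[slot][h][d]
            heads_out.append(acc)
        out.append(heads_out)
    return out
```

Comparing this against `token_attention_fwd` on a small batch (within float16 tolerance) is one way to sanity-check index or stride changes in the kernels.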