mirror of https://github.com/hpcaitech/ColossalAI
[Kernels]Update triton kernels into 2.1.0 (#5046)
* update flash-context-attention * adding kernels * fix * reset * add build script * add building process * add llama2 exmaple * add colossal-llama2 test * clean * fall back test setting * fix test file * clean * clean * clean --------- Co-authored-by: cuiqing.li <lixx336@gmail.com>pull/5060/head
parent
43ad0d9ef0
commit
28052a71fb
|
@ -69,11 +69,11 @@ cd lightllm
|
|||
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
|
||||
pip3 install -e .
|
||||
|
||||
# also, install xformers from source:
|
||||
pip install ninja
|
||||
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
|
||||
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
||||
|
||||
# install flash-attention
|
||||
git clone -recursive https://github.com/Dao-AILab/flash-attention
|
||||
cd flash-attention
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Docker
|
||||
|
@ -95,10 +95,11 @@ cd lightllm
|
|||
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
|
||||
pip3 install -e .
|
||||
|
||||
# install xformers from source
|
||||
pip install ninja
|
||||
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
|
||||
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
||||
# install flash-attention
|
||||
git clone -recursive https://github.com/Dao-AILab/flash-attention
|
||||
cd flash-attention
|
||||
pip install -e .
|
||||
|
||||
```
|
||||
|
||||
### Dive into fast-inference!
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# install triton
|
||||
pip install triton
|
||||
pip install transformers
|
||||
|
||||
# install lightllm and flash-attention
|
||||
mkdir 3rdParty
|
||||
cd 3rdParty
|
||||
git clone https://github.com/ModelTC/lightllm
|
||||
cd lightllm
|
||||
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
|
||||
pip install -e .
|
||||
cd ..
|
||||
|
||||
git clone -recursive https://github.com/Dao-AILab/flash-attention
|
||||
cd flash-attention
|
||||
pip install -e .
|
||||
|
||||
cd ../../
|
||||
|
||||
|
||||
|
||||
|
|
@ -8,15 +8,10 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
|
|||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
try:
|
||||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_llama2_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_context_attention_fwd,
|
||||
context_attention_fwd as lightllm_llama_context_attention_fwd,
|
||||
)
|
||||
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
|
||||
|
||||
|
@ -56,32 +51,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|||
def llama_triton_context_attention(
|
||||
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
|
||||
):
|
||||
if num_key_value_groups == 1:
|
||||
if HAS_LIGHTLLM_KERNEL is False:
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
lightllm_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
# if num_key_value_groups == 1:
|
||||
if HAS_LIGHTLLM_KERNEL is False:
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
else:
|
||||
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
|
||||
lightllm_llama2_context_attention_fwd(
|
||||
lightllm_llama_context_attention_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
|
@ -107,6 +90,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
|
|||
# infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
else:
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
query_states,
|
||||
|
|
|
@ -15,127 +15,223 @@ if HAS_TRITON:
|
|||
this function is modified from
|
||||
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
|
||||
"""
|
||||
if triton.__version__ < "2.1.0":
|
||||
@triton.jit
|
||||
def _context_flash_attention_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
TMP,
|
||||
alibi_ptr,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_tmp_b,
|
||||
stride_tmp_h,
|
||||
stride_tmp_s,
|
||||
# suggtest set-up 64, 128, 256, 512
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
batch_id = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
@triton.jit
|
||||
def _context_flash_attention_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
TMP,
|
||||
alibi_ptr,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_tmp_b,
|
||||
stride_tmp_h,
|
||||
stride_tmp_s,
|
||||
# suggtest set-up 64, 128, 256, 512
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
batch_id = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# get batch info
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
|
||||
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# get batch info
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
|
||||
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
load_p_ptrs = (
|
||||
Q
|
||||
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
|
||||
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
if alibi_ptr is not None:
|
||||
alibi_m = tl.load(alibi_ptr + cur_head)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
load_p_ptrs = (
|
||||
Q
|
||||
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
|
||||
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
if alibi_ptr is not None:
|
||||
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
|
||||
qk -= alibi_loc * alibi_m
|
||||
alibi_m = tl.load(alibi_ptr + cur_head)
|
||||
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs)
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
if alibi_ptr is not None:
|
||||
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
|
||||
qk -= alibi_loc * alibi_m
|
||||
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs)
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_o = (
|
||||
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
return
|
||||
else:
|
||||
@triton.jit
|
||||
def _context_flash_attention_kernel_2(
|
||||
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
|
||||
Out,
|
||||
kv_group_num,
|
||||
stride_qbs, stride_qh, stride_qd,
|
||||
stride_kbs, stride_kh, stride_kd,
|
||||
stride_vbs, stride_vh, stride_vd,
|
||||
stride_obs, stride_oh, stride_od,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
if kv_group_num is not None:
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
off_o = (
|
||||
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
return
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
|
||||
if kv_group_num is None or kv_group_num == 1:
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
else:
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
|
||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
if Alibi is not None:
|
||||
alibi_m = tl.load(Alibi + cur_head)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
if Alibi is not None:
|
||||
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
|
||||
qk -= alibi_loc * alibi_m
|
||||
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
|
||||
|
@ -152,10 +248,9 @@ if HAS_TRITON:
|
|||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
|
||||
|
||||
if triton.__version__ < "2.1.0":
|
||||
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
|
||||
_context_flash_attention_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
|
@ -189,7 +284,28 @@ if HAS_TRITON:
|
|||
num_stages=1,
|
||||
)
|
||||
else:
|
||||
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
|
||||
_context_flash_attention_kernel_2[grid](
|
||||
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
|
||||
o,
|
||||
None,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
@ -220,7 +336,7 @@ if HAS_TRITON:
|
|||
b_start_loc,
|
||||
b_seq_len,
|
||||
tmp,
|
||||
None,
|
||||
None,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
|
@ -244,6 +360,33 @@ if HAS_TRITON:
|
|||
num_stages=1,
|
||||
)
|
||||
else:
|
||||
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
_context_flash_attention_kernel_2[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
None,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
kv_group_num,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,)
|
||||
|
||||
return
|
|
@ -13,17 +13,7 @@ except ImportError:
|
|||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
try:
|
||||
from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
|
||||
token_att_fwd as lightllm_llama2_token_att_fwd,
|
||||
)
|
||||
from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import (
|
||||
token_att_fwd2 as lightllm_llama2_token_att_fwd2,
|
||||
)
|
||||
from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
|
||||
token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
|
||||
)
|
||||
|
||||
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
|
||||
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
|
||||
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
|
||||
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
|
||||
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
|
||||
|
@ -72,7 +62,7 @@ if HAS_TRITON:
|
|||
|
||||
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
|
||||
att_m_tensor = None
|
||||
lightllm_llama_token_att_fw2(
|
||||
lightllm_llama_token_att_fwd2(
|
||||
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
|
||||
)
|
||||
prob = None
|
||||
|
@ -203,7 +193,7 @@ class Llama2TokenAttentionForwards:
|
|||
calcu_shape1 = (batch_size, head_num, head_dim)
|
||||
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
|
||||
|
||||
lightllm_llama2_token_att_fwd(
|
||||
lightllm_llama_token_att_fwd(
|
||||
q,
|
||||
k,
|
||||
att_m_tensor,
|
||||
|
@ -215,12 +205,12 @@ class Llama2TokenAttentionForwards:
|
|||
|
||||
if triton.__version__ == "2.0.0":
|
||||
prob = torch.empty_like(att_m_tensor)
|
||||
lightllm_llama2_token_softmax_fwd(
|
||||
lightllm_llama_token_softmax_fwd(
|
||||
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
|
||||
)
|
||||
att_m_tensor = None
|
||||
|
||||
lightllm_llama2_token_att_fwd2(
|
||||
lightllm_llama_token_att_fwd2(
|
||||
prob,
|
||||
v,
|
||||
attn_out.view(calcu_shape1),
|
||||
|
|
|
@ -28,7 +28,6 @@ def run_llama_test(args):
|
|||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
model.config
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import argparse
|
||||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
TPSIZE = 1
|
||||
BATCH_SIZE = 4
|
||||
MAX_INPUT_LEN = 32
|
||||
MAX_OUTPUT_LEN = 128
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
|
||||
|
||||
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': TPSIZE,
|
||||
}])
|
||||
def run_llama_test(test_config, args):
|
||||
|
||||
model_path = args.path
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
|
||||
text = ["Introduce London.", "What is the genus of Poodle?"]
|
||||
input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
|
||||
|
||||
print(input_ids)
|
||||
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
|
||||
extra_kwargs={"inference_only": True})
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
for o in outputs:
|
||||
output_text = tokenizer.decode(o)
|
||||
print(output_text)
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test(args=args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama(args):
|
||||
spawn(check_llama, args.tp_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p", "--path", type=str, default = "hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path")
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
|
||||
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
|
||||
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
|
||||
parser.add_argument(
|
||||
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
|
||||
)
|
||||
args = parser.parse_args()
|
||||
test_llama(args)
|
|
@ -12,7 +12,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
|
|||
torchrec==0.2.0
|
||||
contexttimer
|
||||
einops
|
||||
triton==2.0.0.dev20221202
|
||||
triton==2.1.0
|
||||
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
|
||||
SentencePiece
|
||||
ninja
|
||||
|
|
|
@ -41,7 +41,6 @@ def test_llama_context_attention():
|
|||
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
|
||||
|
||||
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
|
||||
|
||||
assert torch.allclose(
|
||||
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
|
||||
), "outputs from triton and torch are not matched"
|
||||
|
|
Loading…
Reference in New Issue