mirror of https://github.com/hpcaitech/ColossalAI
[Kernels] Update Triton kernels to 2.1.0 (#5046)

* update flash-context-attention
* adding kernels
* fix
* reset
* add build script
* add building process
* add llama2 example
* add colossal-llama2 test
* clean
* fall back test setting
* fix test file
* clean

Co-authored-by: cuiqing.li <lixx336@gmail.com>

Branch: pull/5060/head
parent: 43ad0d9ef0
commit: 28052a71fb
@@ -69,11 +69,11 @@ cd lightllm
 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 pip3 install -e .

-# also, install xformers from source:
-pip install ninja
-# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
-pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+# install flash-attention
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .

 ```

 ### Docker
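The hunk above swaps the xformers source build for a flash-attention source install. As a quick, hypothetical sanity check (not part of the commit), the snippet below confirms that Triton and flash-attention import cleanly and that the Triton version matches what the updated kernels expect; `flash_attn` is flash-attention's import name, everything else here is illustrative.

```python
# Hypothetical post-install check (not from this commit): report Triton and
# flash-attention versions and warn if the legacy Triton path would be used.
from packaging import version

import triton

try:
    import flash_attn  # installed above via `pip install -e .` in flash-attention
    flash_attn_version = getattr(flash_attn, "__version__", "unknown")
except ImportError:
    flash_attn_version = None

print("triton:", triton.__version__)
print("flash-attn:", flash_attn_version)

# The updated ColossalAI kernels branch on Triton 2.1.0.
if version.parse(triton.__version__) < version.parse("2.1.0"):
    print("Warning: Triton < 2.1.0 detected; the legacy kernel path will be used.")
```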
@@ -95,10 +95,11 @@ cd lightllm
 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 pip3 install -e .

-# install xformers from source
-pip install ninja
-# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
-pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+# install flash-attention
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .
+
 ```

 ### Dive into fast-inference!
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+# install triton
+pip install triton
+pip install transformers
+
+# install lightllm and flash-attention
+mkdir 3rdParty
+cd 3rdParty
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip install -e .
+cd ..
+
+git clone --recursive https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+pip install -e .
+
+cd ../../
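Once the build script above has run, the pinned lightllm checkout should expose the Triton kernels that the modified ColossalAI files import. Below is a hypothetical probe (not part of the commit) that reports which of those kernel modules are importable; the module paths are the ones that appear in the import hunks later in this diff.

```python
# Hypothetical import probe (not from this commit). Each path below is imported
# somewhere in the modified ColossalAI inference code.
import importlib

LIGHTLLM_KERNEL_MODULES = [
    "lightllm.models.llama.triton_kernel.context_flashattention_nopad",
    "lightllm.models.llama.triton_kernel.rotary_emb",
    "lightllm.models.llama.triton_kernel.token_attention_nopad_att1",
    "lightllm.models.llama.triton_kernel.token_attention_nopad_softmax",
    "lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV",
    "lightllm.models.bloom.triton_kernel.token_attention_nopad_att1",
]

for name in LIGHTLLM_KERNEL_MODULES:
    try:
        importlib.import_module(name)
        print(f"OK   {name}")
    except ImportError as exc:
        print(f"FAIL {name}: {exc}")
```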
@@ -8,15 +8,10 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
 from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

 from ._utils import copy_kv_to_mem_cache

 try:
-    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
-        context_attention_fwd as lightllm_llama2_context_attention_fwd,
-    )
     from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
-        context_attention_fwd as lightllm_context_attention_fwd,
+        context_attention_fwd as lightllm_llama_context_attention_fwd,
     )
     from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
@@ -56,32 +51,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
 def llama_triton_context_attention(
     query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
 ):
-    if num_key_value_groups == 1:
+    # if num_key_value_groups == 1:
     if HAS_LIGHTLLM_KERNEL is False:
         llama_context_attn_fwd(
             query_states,
             key_states,
             value_states,
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
             # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
-        else:
-            lightllm_context_attention_fwd(
-                query_states,
-                key_states,
-                value_states,
-                attn_output,
-                infer_state.start_loc,
-                infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
-                infer_state.max_len_in_batch,
-            )
     else:
-        assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
-        lightllm_llama2_context_attention_fwd(
+        lightllm_llama_context_attention_fwd(
             query_states,
             key_states,
             value_states,
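The wrapper above hands the kernels `infer_state.start_loc`, `infer_state.seq_len`, and `infer_state.max_len_in_batch`, i.e. a padding-free batch in which all sequences are concatenated along the token axis. The sketch below is illustrative only (ColossalAI's `BatchInferState` builds these tensors internally); it shows how that metadata relates to per-sequence lengths.

```python
# Illustrative sketch (not ColossalAI code): the context-attention kernels take a
# "packed" batch, i.e. all sequences concatenated, described by per-sequence start
# offsets and lengths instead of padding masks.
import torch

seq_len = torch.tensor([5, 3, 7])                         # length of each sequence
start_loc = torch.zeros_like(seq_len)
start_loc[1:] = torch.cumsum(seq_len, dim=0)[:-1]         # where each sequence starts
max_len_in_batch = int(seq_len.max())
total_tokens = int(seq_len.sum())

print(start_loc.tolist(), max_len_in_batch, total_tokens)  # [0, 5, 8] 7 15
```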
@@ -107,6 +90,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
             # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )

     else:
         Llama2TokenAttentionForwards.token_attn(
             query_states,
@@ -15,127 +15,223 @@ if HAS_TRITON:
     this function is modified from
     https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
     """
+    if triton.__version__ < "2.1.0":
         @triton.jit
         def _context_flash_attention_kernel(
             Q,
             K,
             V,
             sm_scale,
             B_Start_Loc,
             B_Seqlen,
             TMP,
             alibi_ptr,
             Out,
             stride_qbs,
             stride_qh,
             stride_qd,
             stride_kbs,
             stride_kh,
             stride_kd,
             stride_vbs,
             stride_vh,
             stride_vd,
             stride_obs,
             stride_oh,
             stride_od,
             stride_tmp_b,
             stride_tmp_h,
             stride_tmp_s,
             # suggested set-up: 64, 128, 256, 512
             BLOCK_M: tl.constexpr,
             BLOCK_DMODEL: tl.constexpr,
             BLOCK_N: tl.constexpr,
         ):
             batch_id = tl.program_id(0)
             cur_head = tl.program_id(1)
             start_m = tl.program_id(2)

             # initialize offsets
             offs_n = tl.arange(0, BLOCK_N)
             offs_d = tl.arange(0, BLOCK_DMODEL)
             offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

             # get batch info
             cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
             cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
             block_start_loc = BLOCK_M * start_m

             load_p_ptrs = (
                 Q
                 + (cur_batch_start_index + offs_m[:, None]) * stride_qbs
                 + cur_head * stride_qh
                 + offs_d[None, :] * stride_qd
             )
             q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

             k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
             v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
             t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s

             m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
             l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
             acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

             if alibi_ptr is not None:
                 alibi_m = tl.load(alibi_ptr + cur_head)

             block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

             for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
                 start_n = tl.multiple_of(start_n, BLOCK_N)
                 k = tl.load(
                     k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
                     mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
                     other=0.0,
                 )

                 qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                 qk += tl.dot(q, k)
                 qk *= sm_scale

                 if alibi_ptr is not None:
                     alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
                     qk -= alibi_loc * alibi_m

                 qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

                 m_ij = tl.max(qk, 1)
                 p = tl.exp(qk - m_ij[:, None])
                 l_ij = tl.sum(p, 1)
                 # -- update m_i and l_i
                 m_i_new = tl.maximum(m_i, m_ij)
                 alpha = tl.exp(m_i - m_i_new)
                 beta = tl.exp(m_ij - m_i_new)
                 l_i_new = alpha * l_i + beta * l_ij
                 # -- update output accumulator --
                 # scale p
                 p_scale = beta / l_i_new
                 p = p * p_scale[:, None]
                 # scale acc
                 acc_scale = l_i / l_i_new * alpha
                 tl.store(t_ptrs, acc_scale)
                 acc_scale = tl.load(t_ptrs)
                 acc = acc * acc_scale[:, None]
                 # update acc
                 v = tl.load(
                     v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
                     mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
                     other=0.0,
                 )

                 p = p.to(v.dtype)
                 acc += tl.dot(p, v)
                 # update m_i and l_i
                 l_i = l_i_new
                 m_i = m_i_new

             off_o = (
                 (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
             )
             out_ptrs = Out + off_o
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return
+    else:
+        @triton.jit
+        def _context_flash_attention_kernel_2(
+            Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
+            Out,
+            kv_group_num,
+            stride_qbs, stride_qh, stride_qd,
+            stride_kbs, stride_kh, stride_kd,
+            stride_vbs, stride_vh, stride_vd,
+            stride_obs, stride_oh, stride_od,
+            BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+            BLOCK_N: tl.constexpr,
+        ):
+            cur_batch = tl.program_id(0)
+            cur_head = tl.program_id(1)
+            start_m = tl.program_id(2)
+
+            if kv_group_num is not None:
+                cur_kv_head = cur_head // kv_group_num
+
+            cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+            cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+            block_start_loc = BLOCK_M * start_m
+
+            # initialize offsets
+            offs_n = tl.arange(0, BLOCK_N)
+            offs_d = tl.arange(0, BLOCK_DMODEL)
+            offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+            off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+            if kv_group_num is None or kv_group_num == 1:
+                off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
+                off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
+            else:
+                off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
+                off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+
+            q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+            k_ptrs = K + off_k
+            v_ptrs = V + off_v
+
+            m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+            l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+            if Alibi is not None:
+                alibi_m = tl.load(Alibi + cur_head)
+
+            block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+            for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+                start_n = tl.multiple_of(start_n, BLOCK_N)
+                # -- compute qk ----
+                k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+
+                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+                qk += tl.dot(q, k)
+                qk *= sm_scale
+
+                if Alibi is not None:
+                    alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
+                    qk -= alibi_loc * alibi_m
+
+                qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+                m_ij = tl.max(qk, 1)
+                p = tl.exp(qk - m_ij[:, None])
+                l_ij = tl.sum(p, 1)
+                # -- update m_i and l_i
+                m_i_new = tl.maximum(m_i, m_ij)
+                alpha = tl.exp(m_i - m_i_new)
+                beta = tl.exp(m_ij - m_i_new)
+                l_i_new = alpha * l_i + beta * l_ij
+                # -- update output accumulator --
+                # scale p
+                p_scale = beta / l_i_new
+                p = p * p_scale[:, None]
+                # scale acc
+                acc_scale = l_i / l_i_new * alpha
+                acc = acc * acc_scale[:, None]
+                # update acc
+                v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+
+                p = p.to(v.dtype)
+                acc += tl.dot(p, v)
+                # update m_i and l_i
+                l_i = l_i_new
+                m_i = m_i_new
+            # initialize pointers to output
+            off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+            out_ptrs = Out + off_o
+            tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+            return

     @torch.no_grad()
     def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
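Both kernels above implement the same streaming ("online") softmax: for each query block they carry a running row maximum `m_i`, a running normalizer `l_i`, and a rescaled output accumulator `acc`, so the full attention matrix is never materialized. The following plain-PyTorch rendering of that recurrence is offered only as a readability aid (single head, causal masking and ALiBi omitted); it is not part of the commit.

```python
# Readability sketch of the m_i / l_i / acc update used by the Triton kernels.
import torch

def streaming_attention(q, k, v, block_n=16):
    """Block-wise attention with a running max and normalizer (no masking)."""
    sm_scale = q.shape[-1] ** -0.5
    m_i = torch.full((q.shape[0],), float("-inf"))   # running row max
    l_i = torch.zeros(q.shape[0])                    # running softmax denominator
    acc = torch.zeros_like(q)                        # running (rescaled) output

    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q @ kb.T) * sm_scale                   # scores for this K block

        m_ij = qk.max(dim=1).values
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim=1)

        m_i_new = torch.maximum(m_i, m_ij)
        alpha = torch.exp(m_i - m_i_new)             # rescales what is already accumulated
        beta = torch.exp(m_ij - m_i_new)             # rescales the current block
        l_i_new = alpha * l_i + beta * l_ij

        acc = acc * (l_i / l_i_new * alpha)[:, None] + (p * (beta / l_i_new)[:, None]) @ vb
        m_i, l_i = m_i_new, l_i_new
    return acc

# Sanity check against the dense softmax formulation.
q, k, v = torch.randn(8, 32), torch.randn(40, 32), torch.randn(40, 32)
ref = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
assert torch.allclose(streaming_attention(q, k, v), ref, atol=1e-4)
```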
@@ -152,10 +248,9 @@ if HAS_TRITON:
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

         num_warps = 4 if Lk <= 64 else 8
-        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)

         if triton.__version__ < "2.1.0":
+            tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
             _context_flash_attention_kernel[grid](
                 q,
                 k,
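`bloom_context_attn_fwd` selects a kernel by comparing `triton.__version__` against the string "2.1.0". A hedged alternative, not what the commit does, is to compare parsed versions, which also behaves predictably for dev builds and local-version suffixes:

```python
# Alternative version check (illustrative only; the committed code compares the
# raw version strings directly).
from packaging import version
import triton

TRITON_GE_2_1 = version.parse(triton.__version__) >= version.parse("2.1.0")
print("use the Triton 2.1 kernel path:", TRITON_GE_2_1)
```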
@@ -189,7 +284,28 @@ if HAS_TRITON:
                 num_stages=1,
             )
         else:
-            raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
+            _context_flash_attention_kernel_2[grid](
+                q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
+                o,
+                None,
+                q.stride(0),
+                q.stride(1),
+                q.stride(2),
+                k.stride(0),
+                k.stride(1),
+                k.stride(2),
+                v.stride(0),
+                v.stride(1),
+                v.stride(2),
+                o.stride(0),
+                o.stride(1),
+                o.stride(2),
+                BLOCK_M=BLOCK,
+                BLOCK_DMODEL=Lk,
+                BLOCK_N=BLOCK,
+                num_warps=num_warps,
+                num_stages=1,
+            )

         return
@@ -220,7 +336,7 @@ if HAS_TRITON:
                 b_start_loc,
                 b_seq_len,
                 tmp,
                 None,
                 o,
                 q.stride(0),
                 q.stride(1),
@@ -244,6 +360,33 @@ if HAS_TRITON:
                 num_stages=1,
             )
         else:
-            raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
+            kv_group_num = q.shape[1] // k.shape[1]
+            _context_flash_attention_kernel_2[grid](
+                q,
+                k,
+                v,
+                sm_scale,
+                None,
+                b_start_loc,
+                b_seq_len,
+                o,
+                kv_group_num,
+                q.stride(0),
+                q.stride(1),
+                q.stride(2),
+                k.stride(0),
+                k.stride(1),
+                k.stride(2),
+                v.stride(0),
+                v.stride(1),
+                v.stride(2),
+                o.stride(0),
+                o.stride(1),
+                o.stride(2),
+                BLOCK_M=BLOCK,
+                BLOCK_DMODEL=Lk,
+                BLOCK_N=BLOCK,
+                num_warps=num_warps,
+                num_stages=1,)

         return
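For Triton >= 2.1.0 the llama wrapper derives `kv_group_num = q.shape[1] // k.shape[1]`, and inside `_context_flash_attention_kernel_2` each query head reads from KV head `cur_head // kv_group_num`, i.e. grouped-query attention. A small illustrative PyTorch sketch of that head mapping (the shapes are made up, not taken from the commit):

```python
# Illustrative grouped-query attention head mapping (not ColossalAI code).
import torch

num_q_heads, num_kv_heads, head_dim, tokens = 32, 8, 128, 16
q = torch.randn(tokens, num_q_heads, head_dim)
k = torch.randn(tokens, num_kv_heads, head_dim)

kv_group_num = q.shape[1] // k.shape[1]        # here: 4 query heads share one KV head
q_head = 13
kv_head = q_head // kv_group_num               # the KV head this query head attends with
print(kv_group_num, kv_head)                   # 4 3

# Equivalently, KV heads can be expanded to the query-head count:
k_expanded = k.repeat_interleave(kv_group_num, dim=1)
assert k_expanded.shape[1] == num_q_heads
assert torch.equal(k_expanded[:, q_head], k[:, kv_head])
```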
@@ -13,17 +13,7 @@ except ImportError:
     print("please install triton from https://github.com/openai/triton")

 try:
-    from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
-        token_att_fwd as lightllm_llama2_token_att_fwd,
-    )
-    from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import (
-        token_att_fwd2 as lightllm_llama2_token_att_fwd2,
-    )
-    from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
-        token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
-    )
-
-    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
     from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
     from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
     from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
@@ -72,7 +62,7 @@ if HAS_TRITON:

         lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
         att_m_tensor = None
-        lightllm_llama_token_att_fw2(
+        lightllm_llama_token_att_fwd2(
             prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
         )
         prob = None
@@ -203,7 +193,7 @@ class Llama2TokenAttentionForwards:
         calcu_shape1 = (batch_size, head_num, head_dim)
         att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")

-        lightllm_llama2_token_att_fwd(
+        lightllm_llama_token_att_fwd(
             q,
             k,
             att_m_tensor,
@@ -215,12 +205,12 @@ class Llama2TokenAttentionForwards:

         if triton.__version__ == "2.0.0":
             prob = torch.empty_like(att_m_tensor)
-            lightllm_llama2_token_softmax_fwd(
+            lightllm_llama_token_softmax_fwd(
                 att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
             )
             att_m_tensor = None

-            lightllm_llama2_token_att_fwd2(
+            lightllm_llama_token_att_fwd2(
                 prob,
                 v,
                 attn_out.view(calcu_shape1),
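`Llama2TokenAttentionForwards.token_attn` runs decode-time (single-token) attention in three stages: score computation (`token_att_fwd`), softmax over the cached tokens (`token_softmax_fwd`), and a probability-weighted reduction over V (`token_att_fwd2`). The dense PyTorch reference below mirrors those three stages for illustration only; the real kernels gather K/V through lightllm's paged cache via `kv_cache_loc` rather than from dense tensors.

```python
# Illustrative three-stage decode attention for one new token per sequence
# (dense tensors instead of lightllm's paged KV cache).
import torch

batch, heads, head_dim, cached_tokens = 2, 8, 64, 10
q = torch.randn(batch, heads, head_dim)                  # one query token per sequence
k_cache = torch.randn(batch, cached_tokens, heads, head_dim)
v_cache = torch.randn(batch, cached_tokens, heads, head_dim)

# Stage 1: attention scores against every cached token (token_att_fwd).
att = torch.einsum("bhd,bthd->bht", q, k_cache) / head_dim ** 0.5

# Stage 2: softmax over the cached-token axis (token_softmax_fwd).
prob = torch.softmax(att, dim=-1)

# Stage 3: probability-weighted reduction over V (token_att_fwd2).
out = torch.einsum("bht,bthd->bhd", prob, v_cache)
print(out.shape)                                         # torch.Size([2, 8, 64])
```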
@@ -28,7 +28,6 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
-    model.config

     shard_config = ShardConfig(
         enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
@@ -0,0 +1,81 @@
+import os
+import warnings
+
+import torch
+import torch.distributed as dist
+import argparse
+from packaging import version
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+TPSIZE = 1
+BATCH_SIZE = 4
+MAX_INPUT_LEN = 32
+MAX_OUTPUT_LEN = 128
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+
+@parameterize('test_config', [{
+    'tp_size': TPSIZE,
+}])
+def run_llama_test(test_config, args):
+
+    model_path = args.path
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer.pad_token_id = tokenizer.unk_token_id
+    model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
+    model = model.half()
+
+    text = ["Introduce London.", "What is the genus of Poodle?"]
+    input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
+
+    print(input_ids)
+
+    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
+                               extra_kwargs={"inference_only": True})
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+    outputs = infer_engine.generate(input_ids, **generate_kwargs)
+
+    assert outputs is not None
+
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        for o in outputs:
+            output_text = tokenizer.decode(o)
+            print(output_text)
+
+
+def check_llama(rank, world_size, port, args):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_llama_test(args=args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama(args):
+    spawn(check_llama, args.tp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-p", "--path", type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path")
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
+    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
+    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
+    parser.add_argument(
+        "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
+    )
+    args = parser.parse_args()
+    test_llama(args)
@@ -12,7 +12,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
 torchrec==0.2.0
 contexttimer
 einops
-triton==2.0.0.dev20221202
+triton==2.1.0
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
 ninja
@@ -41,7 +41,6 @@ def test_llama_context_attention():
     llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
-
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

     assert torch.allclose(
         torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
     ), "outputs from triton and torch are not matched"
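The test above checks the Triton kernel against a PyTorch reference with `torch.allclose`. The reference, `torch_context_attention`, is defined elsewhere in the test file and is not shown in this hunk; a hypothetical stand-in in the same spirit is a plain per-sequence causal attention:

```python
# Hypothetical dense reference (illustrative only; not the test's actual
# torch_context_attention): causal attention computed per sequence and head.
import torch

def dense_causal_attention(q, k, v):
    # q, k, v: [bs, seq_len, head_num, head_dim]
    bs, seq_len, head_num, head_dim = q.shape
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))       # -> [bs, head, seq, dim]
    scores = q @ k.transpose(-1, -2) / head_dim ** 0.5
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v
    return out.permute(0, 2, 1, 3)                             # back to [bs, seq, head, dim]
```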