[Kernels]Update triton kernels into 2.1.0 (#5046)

* update flash-context-attention

* adding kernels

* fix

* reset

* add build script

* add building process

* add llama2 exmaple

* add colossal-llama2 test

* clean

* fall back test setting

* fix test file

* clean

* clean

* clean

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
pull/5060/head
Cuiqing Li (李崔卿) 2023-11-16 16:43:15 +08:00 committed by GitHub
parent 43ad0d9ef0
commit 28052a71fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 392 additions and 171 deletions

View File

@ -69,11 +69,11 @@ cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e . pip3 install -e .
# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
# install flash-attention
git clone -recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .
``` ```
### Docker ### Docker
@ -95,10 +95,11 @@ cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e . pip3 install -e .
# install xformers from source # install flash-attention
pip install ninja git clone -recursive https://github.com/Dao-AILab/flash-attention
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types cd flash-attention
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers pip install -e .
``` ```
### Dive into fast-inference! ### Dive into fast-inference!

View File

@ -0,0 +1,24 @@
#!/usr/bin/env bash
# install triton
pip install triton
pip install transformers
# install lightllm and flash-attention
mkdir 3rdParty
cd 3rdParty
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip install -e .
cd ..
git clone -recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .
cd ../../

View File

@ -8,15 +8,10 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from ._utils import copy_kv_to_mem_cache from ._utils import copy_kv_to_mem_cache
try: try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_context_attention_fwd, context_attention_fwd as lightllm_llama_context_attention_fwd,
) )
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
@ -56,7 +51,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
def llama_triton_context_attention( def llama_triton_context_attention(
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
): ):
if num_key_value_groups == 1: # if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False: if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd( llama_context_attn_fwd(
query_states, query_states,
@ -69,19 +64,7 @@ def llama_triton_context_attention(
infer_state.max_len_in_batch, infer_state.max_len_in_batch,
) )
else: else:
lightllm_context_attention_fwd( lightllm_llama_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
lightllm_llama2_context_attention_fwd(
query_states, query_states,
key_states, key_states,
value_states, value_states,
@ -107,6 +90,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
# infer_state.cache_manager.past_key_values_length, # infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch, infer_state.max_len_in_batch,
) )
else: else:
Llama2TokenAttentionForwards.token_attn( Llama2TokenAttentionForwards.token_attn(
query_states, query_states,

View File

@ -15,7 +15,7 @@ if HAS_TRITON:
this function is modified from this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
""" """
if triton.__version__ < "2.1.0":
@triton.jit @triton.jit
def _context_flash_attention_kernel( def _context_flash_attention_kernel(
Q, Q,
@ -136,6 +136,102 @@ if HAS_TRITON:
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return return
else:
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
Out,
kv_group_num,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
if kv_group_num is not None:
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
else:
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if Alibi is not None:
alibi_m = tl.load(Alibi + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if Alibi is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad() @torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
@ -153,9 +249,8 @@ if HAS_TRITON:
num_warps = 4 if Lk <= 64 else 8 num_warps = 4 if Lk <= 64 else 8
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
if triton.__version__ < "2.1.0": if triton.__version__ < "2.1.0":
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid]( _context_flash_attention_kernel[grid](
q, q,
k, k,
@ -189,7 +284,28 @@ if HAS_TRITON:
num_stages=1, num_stages=1,
) )
else: else:
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") _context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
o,
None,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return return
@ -244,6 +360,33 @@ if HAS_TRITON:
num_stages=1, num_stages=1,
) )
else: else:
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") kv_group_num = q.shape[1] // k.shape[1]
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
None,
b_start_loc,
b_seq_len,
o,
kv_group_num,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,)
return return

View File

@ -13,17 +13,7 @@ except ImportError:
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
try: try:
from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
token_att_fwd as lightllm_llama2_token_att_fwd,
)
from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama2_token_att_fwd2,
)
from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
@ -72,7 +62,7 @@ if HAS_TRITON:
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None att_m_tensor = None
lightllm_llama_token_att_fw2( lightllm_llama_token_att_fwd2(
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
) )
prob = None prob = None
@ -203,7 +193,7 @@ class Llama2TokenAttentionForwards:
calcu_shape1 = (batch_size, head_num, head_dim) calcu_shape1 = (batch_size, head_num, head_dim)
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
lightllm_llama2_token_att_fwd( lightllm_llama_token_att_fwd(
q, q,
k, k,
att_m_tensor, att_m_tensor,
@ -215,12 +205,12 @@ class Llama2TokenAttentionForwards:
if triton.__version__ == "2.0.0": if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor) prob = torch.empty_like(att_m_tensor)
lightllm_llama2_token_softmax_fwd( lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
) )
att_m_tensor = None att_m_tensor = None
lightllm_llama2_token_att_fwd2( lightllm_llama_token_att_fwd2(
prob, prob,
v, v,
attn_out.view(calcu_shape1), attn_out.view(calcu_shape1),

View File

@ -28,7 +28,6 @@ def run_llama_test(args):
tokenizer.pad_token_id = tokenizer.unk_token_id tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half() model = model.half()
model.config
shard_config = ShardConfig( shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}

View File

@ -0,0 +1,81 @@
import os
import warnings
import torch
import torch.distributed as dist
import argparse
from packaging import version
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from transformers import AutoModelForCausalLM, AutoTokenizer
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 1
BATCH_SIZE = 4
MAX_INPUT_LEN = 32
MAX_OUTPUT_LEN = 128
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
@parameterize('test_config', [{
'tp_size': TPSIZE,
}])
def run_llama_test(test_config, args):
model_path = args.path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
text = ["Introduce London.", "What is the genus of Poodle?"]
input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
print(input_ids)
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
extra_kwargs={"inference_only": True})
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
outputs = infer_engine.generate(input_ids, **generate_kwargs)
assert outputs is not None
if not dist.is_initialized() or dist.get_rank() == 0:
for o in outputs:
output_text = tokenizer.decode(o)
print(output_text)
def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test(args=args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama(args):
spawn(check_llama, args.tp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, default = "hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path")
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
parser.add_argument(
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
)
args = parser.parse_args()
test_llama(args)

View File

@ -12,7 +12,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
torchrec==0.2.0 torchrec==0.2.0
contexttimer contexttimer
einops einops
triton==2.0.0.dev20221202 triton==2.1.0
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece SentencePiece
ninja ninja

View File

@ -41,7 +41,6 @@ def test_llama_context_attention():
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose( assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
), "outputs from triton and torch are not matched" ), "outputs from triton and torch are not matched"