diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh
deleted file mode 100644
index 6a73f6f0b..000000000
--- a/colossalai/inference/build.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/env bash
-
-# install triton
-pip install triton
-pip install transformers
-
-# install lightllm and flash-attention
-mkdir 3rdParty
-cd 3rdParty
-git clone https://github.com/ModelTC/lightllm
-cd lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
-pip install -e .
-cd ..
-
-git clone -recursive https://github.com/Dao-AILab/flash-attention
-cd flash-attention
-pip install -e .
-
-cd ../../
-
-
-
-
diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
index 2dd1858d6..b7bc94d0e 100644
--- a/colossalai/inference/engine/modeling/llama.py
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -27,9 +27,15 @@ except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
     HAS_LIGHTLLM_KERNEL = False

+try:
+    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+    HAS_TRITON_FLASH_DECODING_KERNEL = True
+except:
+    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_TRITON_FLASH_DECODING_KERNEL = False
+
 try:
     from flash_attn import flash_attn_with_kvcache
-
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -42,7 +48,6 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)

-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -67,7 +72,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
         else:
@@ -78,7 +82,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
     else:
@@ -90,13 +93,20 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )

-
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
-    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
+        token_flash_decoding(q = query_states,
+                             o_tensor = attn_output,
+                             infer_state = infer_state,
+                             q_head_num = q_head_num,
+                             head_dim = head_dim,
+                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        return
+
     if num_key_value_groups == 1:
         token_attention_fwd(
             query_states,
@@ -106,7 +116,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -118,7 +127,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
             infer_state.other_kv_index,
         )
@@ -451,10 +459,14 @@ class LlamaInferenceForwards:
         )

         if HAS_LIGHTLLM_KERNEL:
+            attn_output = torch.empty_like(query_states)

-            llama_triton_token_attention(
-                query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
-            )
+            llama_triton_token_attention(query_states = query_states,
+                                         attn_output = attn_output,
+                                         infer_state = infer_state,
+                                         num_key_value_groups = self.num_key_value_groups,
+                                         q_head_num = q_len * self.num_heads,
+                                         head_dim = self.head_dim)
         else:
             self.num_heads // self.num_key_value_heads
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
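
Note on the llama.py change above: `llama_triton_token_attention` keeps backward-compatible defaults (`q_head_num = -1`, `head_dim = -1`), so callers that do not pass the new arguments still fall through to lightllm's `token_attention_fwd`; only the decode path that supplies both values is routed to the Triton flash-decoding kernel. A small, self-contained sketch of the shapes involved (the batch size, head count and head dimension are illustrative values, not taken from the patch):

    import torch

    # Illustrative decode-step sizes: 4 sequences, 32 query heads, head_dim 128.
    batch_size, num_heads, head_dim = 4, 32, 128
    q_len = 1                                   # decode processes one new token per sequence

    # What the patched forward now passes down:
    q_head_num = q_len * num_heads              # -> 32

    query_states = torch.randn(batch_size, num_heads * head_dim)
    attn_output = torch.empty_like(query_states)

    # Inside token_flash_decoding both tensors are viewed as (batch, q_head_num, head_dim),
    # which is why the caller must now provide q_head_num and head_dim explicitly.
    calcu_shape1 = (batch_size, q_head_num, head_dim)
    print(query_states.view(calcu_shape1).shape, attn_output.view(calcu_shape1).shape)
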
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
index 1ad7a80eb..3d9a23d2f 100644
--- a/colossalai/kernel/triton/context_attention.py
+++ b/colossalai/kernel/triton/context_attention.py
@@ -137,6 +137,7 @@ if HAS_TRITON:
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return
     else:
+        # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
         @triton.jit
         def _context_flash_attention_kernel_2(
             Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
new file mode 100644
index 000000000..9b7b27fa1
--- /dev/null
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -0,0 +1,50 @@
+# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
+import torch
+try:
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_LIGHTLLM_KERNEL = False
+
+
+if HAS_LIGHTLLM_KERNEL:
+    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
+        BLOCK_SEQ = 256
+        batch_size = infer_state.batch_size
+        max_len_in_batch = infer_state.max_len_in_batch
+
+
+        calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+        if getattr(infer_state, 'mid_o', None) is None:
+            infer_state.mid_o = torch.empty([batch_size,
+                                             q_head_num,
+                                             max_len_in_batch // BLOCK_SEQ + 1,
+                                             head_dim],
+                                            dtype=torch.float32,
+                                            device="cuda")
+            infer_state.mid_o_logexpsum = torch.empty([batch_size,
+                                                       q_head_num,
+                                                       max_len_in_batch // BLOCK_SEQ + 1],
+                                                      dtype=torch.float32,
+                                                      device="cuda")
+
+        mid_o = infer_state.mid_o
+        mid_o_logexpsum = infer_state.mid_o_logexpsum
+
+        flash_decode_stage1(q.view(calcu_shape1),
+                            cache_k,
+                            cache_v,
+                            infer_state.block_loc,
+                            infer_state.seq_len,
+                            infer_state.max_len_in_batch,
+                            mid_o,
+                            mid_o_logexpsum,
+                            BLOCK_SEQ)
+        flash_decode_stage2(mid_o,
+                            mid_o_logexpsum,
+                            infer_state.seq_len,
+                            o_tensor.view(calcu_shape1),
+                            BLOCK_SEQ)
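
The new wrapper follows the usual two-stage flash-decoding split: `flash_decode_stage1` walks each sequence's cached KV in chunks of `BLOCK_SEQ = 256` tokens and writes one partial output plus one log-sum-exp per chunk into `mid_o` / `mid_o_logexpsum`, and `flash_decode_stage2` reduces those partials into `o_tensor`. A self-contained PyTorch sketch of that split-and-combine math for a single head (the sizes and tensors are made up for illustration; the real kernels read the paged cache through `block_loc`):

    import torch

    torch.manual_seed(0)

    # Illustrative sizes: one decode query attending to a 1000-token KV cache.
    head_dim, seq_len, BLOCK_SEQ = 128, 1000, 256
    scale = head_dim ** -0.5

    q = torch.randn(head_dim)
    k = torch.randn(seq_len, head_dim)
    v = torch.randn(seq_len, head_dim)

    # Stage 1: one normalized partial output and one log-sum-exp per BLOCK_SEQ chunk.
    mid_o, mid_logexpsum = [], []
    for start in range(0, seq_len, BLOCK_SEQ):
        scores = (k[start:start + BLOCK_SEQ] @ q) * scale
        m = scores.max()
        p = torch.exp(scores - m)
        mid_o.append((p @ v[start:start + BLOCK_SEQ]) / p.sum())
        mid_logexpsum.append(m + torch.log(p.sum()))

    # Stage 2: combine the partials, weighting each chunk by its share of the softmax mass.
    weights = torch.softmax(torch.stack(mid_logexpsum), dim=0)
    o = (weights[:, None] * torch.stack(mid_o)).sum(dim=0)

    # Matches single-pass attention over the whole cache.
    ref = torch.softmax((k @ q) * scale, dim=0) @ v
    print(torch.allclose(o, ref, atol=1e-5))
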
parser.add_argument("-p", "--path", type=str, help="Model path", required=True) - parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size") - parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=8, help="Maximum batch size") - parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length") + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size") + parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length") + parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length") parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size") args = parser.parse_args() diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt index 461dcb23b..3151504df 100644 --- a/requirements/requirements-infer.txt +++ b/requirements/requirements-infer.txt @@ -2,6 +2,6 @@ transformers==4.34.0 packaging ninja auto-gptq==0.5.0 -git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039 +git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8 git+https://github.com/facebookresearch/xformers.git@main#egg=xformers git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9