mirror of https://github.com/hpcaitech/ColossalAI

[Kernels] Updated Triton kernels to 2.1.0 and added flash-decoding for Llama token attention (#4965)

* adding flash-decoding
* clean
* adding kernel
* adding flash-decoding
* add integration
* add
* adding kernel
* adding kernel
* adding triton 2.1.0 features for inference
* update bloom triton kernel
* remove useless vllm kernels
* clean codes
* fix
* adding files
* fix readme
* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

pull/4992/head
parent cf579ff46d
commit 459a88c806
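For readers new to the term: flash-decoding speeds up the token (decode) phase by splitting the KV cache into chunks, computing partial attention per chunk, and merging the partial results with an online-softmax correction, so long caches can be processed in parallel. The snippet below is a minimal PyTorch reference of that idea only, not the Triton/lightllm kernel this commit wires in; the function name, shapes, and chunk size are illustrative assumptions.

```python
import math
import torch

def flash_decoding_reference(q, k_cache, v_cache, chunk_size=256):
    """Single-token decode attention computed chunk by chunk over the KV cache."""
    # q: (num_heads, head_dim); k_cache, v_cache: (cache_len, num_heads, head_dim)
    num_heads, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)
    acc = torch.zeros(num_heads, head_dim, dtype=torch.float32)
    running_max = torch.full((num_heads, 1), float("-inf"))
    denom = torch.zeros(num_heads, 1)
    for start in range(0, k_cache.shape[0], chunk_size):
        k = k_cache[start:start + chunk_size].float()
        v = v_cache[start:start + chunk_size].float()
        scores = torch.einsum("hd,chd->hc", q.float(), k) * scale
        new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
        p = torch.exp(scores - new_max)                # partial softmax numerator for this chunk
        correction = torch.exp(running_max - new_max)  # rescale contributions of earlier chunks
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + torch.einsum("hc,chd->hd", p, v)
        running_max = new_max
    return (acc / denom).to(q.dtype)

# sanity check against a plain softmax over the whole cache
q = torch.randn(8, 64)
k_cache, v_cache = torch.randn(1000, 8, 64), torch.randn(1000, 8, 64)
probs = torch.softmax(torch.einsum("hd,chd->hc", q, k_cache) / math.sqrt(64), dim=-1)
assert torch.allclose(flash_decoding_reference(q, k_cache, v_cache),
                      torch.einsum("hc,chd->hd", probs, v_cache), atol=1e-4)
```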
@@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with
 - [x] policy
 - [x] context forward
 - [x] token forward
+- [x] support flash-decoding
 - [ ] Replace the kernels with `faster-transformer` in token-forward stage
 - [ ] Support all models
 - [x] Llama
+- [x] Llama-2
 - [x] Bloom
-- [ ] Chatglm2
+- [x] Chatglm2
 - [ ] Benchmarking for all models
 
 ## Get started
@@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm
 git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 cd lightllm
 pip3 install -e .
+
+# also, install xformers from source:
+pip install ninja
+# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
+pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+
 ```
 
 ### Docker
@@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039
 cd lightllm
 pip3 install -e .
 
+# install xformers from source
+pip install ninja
+# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
+pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
 ```
 
 ### Dive into fast-inference!
@@ -311,6 +311,7 @@ class TPInferEngine:
             seq_start_indexes[i] = start_index
             start_index += curr_seq_len
             max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+
         block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
         batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
         batch_infer_state.seq_len = seq_lengths.to("cuda")
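For context on the bookkeeping above: the engine packs variable-length sequences without padding, so it records where each sequence starts in the packed token dimension and how long the longest sequence in the batch is, and BatchInferState carries that metadata to the kernels. A small illustrative sketch of the layout follows; the helper name is hypothetical and not part of the TPInferEngine API.

```python
import torch

def build_batch_layout(seq_lengths: torch.Tensor):
    # seq_lengths: 1-D int tensor, one entry per sequence in the batch
    seq_start_indexes = torch.zeros_like(seq_lengths)
    seq_start_indexes[1:] = torch.cumsum(seq_lengths, dim=0)[:-1]  # [0, len0, len0+len1, ...]
    max_len_in_batch = int(seq_lengths.max())
    return seq_start_indexes, max_len_in_batch

starts, max_len = build_batch_layout(torch.tensor([5, 3, 7]))
# starts -> tensor([0, 5, 8]); max_len -> 7
```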
@@ -19,6 +19,12 @@ from transformers.utils import logging
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
 
+try:
+    from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
+
 
 def generate_alibi(n_head, dtype=torch.float16):
     """
@@ -460,7 +466,10 @@ class BloomInferenceForwards:
             # output = self.output[:batch_size*q_length, :, :]
             output = torch.empty_like(q)
 
-            bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+            if HAS_LIGHTLLM_KERNEL:
+                lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
+            else:
+                bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
 
             context_layer = output.view(batch_size, q_length, H * D_HEAD)
         else:
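The `alibi` argument passed to both kernels holds Bloom's per-head ALiBi slopes, which the attention adds as a distance-proportional bias instead of positional embeddings. As a reminder of what `generate_alibi` computes, here is a hedged sketch of the standard slope formula; it is simplified, and the in-tree function may differ in details such as tensor-parallel partitioning of heads.

```python
import math
import torch

def alibi_slopes(n_head: int, dtype=torch.float16) -> torch.Tensor:
    # Slopes for the closest power of two <= n_head, following the ALiBi paper
    closest_pow2 = 2 ** math.floor(math.log2(n_head))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != n_head:
        # Remaining heads reuse interleaved slopes derived from the next power of two
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        slopes += [extra_base ** (2 * i + 1) for i in range(n_head - closest_pow2)]
    return torch.tensor(slopes, dtype=dtype)

# alibi_slopes(8) -> tensor([0.5000, 0.2500, 0.1250, ..., 0.0039], dtype=torch.float16)
```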
@@ -1,4 +1,6 @@
 from typing import List, Optional, Tuple
+import math
+import copy
 
 import torch
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -10,24 +12,11 @@ from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttention
 
 from ._utils import copy_kv_to_mem_cache
 
-try:
-    from vllm import layernorm_ops, pos_encoding_ops
-
-    rms_norm = layernorm_ops.rms_norm
-    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
-    HAS_VLLM_KERNERL = True
-except:
-    print("fall back to original rotary_embedding_neox of huggingface")
-    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
-    print(
-        "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
-    )
-    HAS_VLLM_KERNERL = False
-
 try:
     from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
         context_attention_fwd as lightllm_llama2_context_attention_fwd,
     )
+    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
     from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
 
     HAS_LIGHTLLM_KERNEL = True
@@ -35,6 +24,13 @@ except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
     HAS_LIGHTLLM_KERNEL = False
 
+try:
+    from flash_attn import flash_attn_with_kvcache
+    HAS_FLASH_KERNEL = True
+except:
+    HAS_FLASH_KERNEL = False
+    print("please install flash attention from https://github.com/Dao-AILab/flash-attention")
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -54,6 +50,71 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+
+def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
+    if num_key_value_groups == 1:
+        if HAS_LIGHTLLM_KERNEL is False:
+            llama_context_attn_fwd(
+                query_states,
+                key_states,
+                value_states,
+                attn_output,
+                infer_state.start_loc,
+                infer_state.seq_len,
+                # infer_state.cache_manager.past_key_values_length,
+                infer_state.max_len_in_batch,
+            )
+        else:
+            lightllm_context_attention_fwd(
+                query_states,
+                key_states,
+                value_states,
+                attn_output,
+                infer_state.start_loc,
+                infer_state.seq_len,
+                # infer_state.cache_manager.past_key_values_length,
+                infer_state.max_len_in_batch,
+            )
+    else:
+        assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
+        lightllm_llama2_context_attention_fwd(
+            query_states,
+            key_states,
+            value_states,
+            attn_output,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+        )
+
+
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
+    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+    if num_key_value_groups == 1:
+        token_attention_fwd(
+            query_states,
+            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+            attn_output,
+            infer_state.block_loc,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+        )
+    else:
+        Llama2TokenAttentionForwards.token_attn(
+            query_states,
+            infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+            attn_output,
+            infer_state.block_loc,
+            infer_state.start_loc,
+            infer_state.seq_len,
+            # infer_state.cache_manager.past_key_values_length,
+            infer_state.max_len_in_batch,
+            infer_state.other_kv_index,
+        )
+
+
 class LlamaInferenceForwards:
     """
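The `num_key_value_groups` switch above separates plain multi-head attention (one group, the original Llama layout) from grouped-query attention (Llama-2 variants with fewer KV heads than query heads), which is why the Llama-2 lightllm kernels are mandatory in the second branch. A hedged sketch of the relationship, assuming HuggingFace-style attribute values, together with the naive fallback of repeating KV heads so a regular MHA kernel could be used:

```python
import torch

num_heads = 32            # query heads, e.g. config.num_attention_heads
num_key_value_heads = 8   # KV heads,    e.g. config.num_key_value_heads
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads share one KV head

# Naive, non-fused fallback: expand the KV heads so an MHA kernel sees matching head counts.
key_states = torch.randn(10, num_key_value_heads, 128)   # (num_tokens, kv_heads, head_dim)
expanded = key_states.repeat_interleave(num_key_value_groups, dim=1)
assert expanded.shape == (10, num_heads, 128)
```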
@@ -204,7 +265,8 @@ class LlamaInferenceForwards:
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
 
+
     @staticmethod
     def llama_decoder_layer_forward(
         self: LlamaDecoderLayer,
@@ -247,6 +309,7 @@ class LlamaInferenceForwards:
         outputs += (present_key_value,)
 
         return outputs
 
+
     @staticmethod
     def llama_flash_attn_kvcache_forward(
@@ -295,27 +358,8 @@ class LlamaInferenceForwards:
                 infer_state.cache_manager,
             )
             attn_output = torch.empty_like(query_states)
 
-            if self.num_key_value_groups == 1:
-                llama_context_attn_fwd(
-                    query_states,
-                    key_states,
-                    value_states,
-                    attn_output,
-                    infer_state.start_loc,
-                    infer_state.seq_len,
-                    infer_state.max_len_in_batch,
-                )
-            else:
-                lightllm_llama2_context_attention_fwd(
-                    query_states,
-                    key_states,
-                    value_states,
-                    attn_output,
-                    infer_state.start_loc,
-                    infer_state.seq_len,
-                    infer_state.max_len_in_batch,
-                )
+            llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -337,35 +381,26 @@ class LlamaInferenceForwards:
                     infer_state.decode_mem_index,
                     infer_state.cache_manager,
                 )
 
-            # second token and follows
-            # kv = torch.stack((key_states, value_states), dim=2)
-            # (batch_size, seqlen, nheads, headdim)
-            attn_output = torch.empty_like(query_states)
-
-            if self.num_key_value_groups == 1:
-                token_attention_fwd(
-                    query_states,
-                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
-                    attn_output,
-                    infer_state.block_loc,
-                    infer_state.start_loc,
-                    infer_state.seq_len,
-                    infer_state.max_len_in_batch,
-                )
+            HAS_LIGHTLLM_KERNEL = False
+            if HAS_LIGHTLLM_KERNEL:
+                attn_output = torch.empty_like(query_states)
+                llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
             else:
-                Llama2TokenAttentionForwards.token_attn(
-                    query_states,
-                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
-                    attn_output,
-                    infer_state.block_loc,
-                    infer_state.start_loc,
-                    infer_state.seq_len,
-                    infer_state.max_len_in_batch,
-                    infer_state.other_kv_index,
-                )
+                heads_per_group = self.num_heads // self.num_key_value_heads
+                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
+                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
+
+                query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+                copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+                copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+
+                attn_output = flash_attn_with_kvcache(q=query_states,
+                                                      k_cache=copy_cache_k,
+                                                      v_cache=copy_cache_v,
+                                                      softmax_scale=1 / math.sqrt(self.head_dim),
+                                                      causal=True)
 
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
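Note that the decode branch above pins `HAS_LIGHTLLM_KERNEL = False` locally, so this forward always takes the `flash_attn_with_kvcache` path. For reference, a hedged usage sketch of that flash-attn API (it requires flash-attn >= 2.2, a CUDA device, and fp16/bf16 tensors; the `cache_seqlens` argument shown here is optional and is not passed in the diff above, which views the cache buffers directly):

```python
import math
import torch
from flash_attn import flash_attn_with_kvcache  # assumes flash-attn >= 2.2 is installed

bsz, nheads, nkv_heads, head_dim, max_cache_len = 2, 32, 32, 128, 512
q = torch.randn(bsz, 1, nheads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(bsz, max_cache_len, nkv_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((bsz,), 17, dtype=torch.int32, device="cuda")  # valid tokens per sequence

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,            # attend only to the filled prefix of the cache
    softmax_scale=1 / math.sqrt(head_dim),  # this is also the kernel's default scale
    causal=True,
)
# out: (bsz, 1, nheads, head_dim)
```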
@@ -374,22 +409,3 @@ class LlamaInferenceForwards:
         # return past_key_value as None
         return attn_output, None, None
 
-
-def get_llama_vllm_rmsnorm_forward():
-    if HAS_VLLM_KERNERL:
-
-        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            x = hidden_states
-            out = torch.empty_like(x)
-            rms_norm(
-                out,
-                x,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-
-            return out
-
-        return _vllm_rmsnorm_forward
-    else:
-        return None
@@ -9,7 +9,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 from ..modeling._utils import init_to_get_rotary
-from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+from ..modeling.llama import LlamaInferenceForwards
 
 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
@@ -105,9 +105,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
-        else:
-            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
-            infer_forward = get_llama_vllm_rmsnorm_forward()
 
         if infer_forward is not None:
             method_replacement = {"forward": partial(infer_forward)}
@@ -5,7 +5,6 @@ import torch
 try:
     import triton
     import triton.language as tl
-
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -155,39 +154,43 @@ if HAS_TRITON:
         num_warps = 4 if Lk <= 64 else 8
 
         tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
 
+        if triton.__version__ < "2.1.0":
             _context_flash_attention_kernel[grid](
                 q,
                 k,
                 v,
                 sm_scale,
                 b_start_loc,
                 b_seq_len,
                 tmp,
                 alibi,
                 o,
                 q.stride(0),
                 q.stride(1),
                 q.stride(2),
                 k.stride(0),
                 k.stride(1),
                 k.stride(2),
                 v.stride(0),
                 v.stride(1),
                 v.stride(2),
                 o.stride(0),
                 o.stride(1),
                 o.stride(2),
                 tmp.stride(0),
                 tmp.stride(1),
                 tmp.stride(2),
                 # manually setting this block num, we can use tuning config to further speed up
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,
             )
+        else:
+            raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
 
         return
 
     @torch.no_grad()
@@ -207,36 +210,40 @@
         tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
         # num_warps = 4
+        if triton.__version__ < "2.1.0":
             _context_flash_attention_kernel[grid](
                 q,
                 k,
                 v,
                 sm_scale,
                 b_start_loc,
                 b_seq_len,
                 tmp,
                 None,
                 o,
                 q.stride(0),
                 q.stride(1),
                 q.stride(2),
                 k.stride(0),
                 k.stride(1),
                 k.stride(2),
                 v.stride(0),
                 v.stride(1),
                 v.stride(2),
                 o.stride(0),
                 o.stride(1),
                 o.stride(2),
                 tmp.stride(0),
                 tmp.stride(1),
                 tmp.stride(2),
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,
             )
+        else:
+            raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
 
         return
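Both hunks above gate on a lexicographic comparison of `triton.__version__`, which happens to work for 2.0.x versus 2.1.0 but is fragile in general (the string "2.10.0" compares lower than "2.2.0", for example). A hedged sketch of a more robust gate using `packaging`; the two branch functions are hypothetical stand-ins for the kernel launches:

```python
from packaging import version  # pip install packaging if it is not already present
import triton

def run_legacy_triton_kernel():   # hypothetical stand-in for the in-tree triton 2.0.x launch
    print("using the in-tree triton 2.0.x kernel")

def run_lightllm_kernel():        # hypothetical stand-in for the lightllm-backed path
    print("using the lightllm kernels")

if version.parse(triton.__version__) < version.parse("2.1.0"):
    run_legacy_triton_kernel()
else:
    run_lightllm_kernel()
```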
@@ -105,8 +105,8 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
     parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
-    parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
+    parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
+    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
     parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
     parser.add_argument(
         "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
@@ -10,6 +10,12 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
+try:
+    import lightllm
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
+
 TP_SIZE = 2
 MAX_BATCH_SIZE = 4
 MAX_INPUT_LEN = 16
@@ -52,7 +58,7 @@ def check_bloom(rank, world_size, port):
     run()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -12,6 +12,12 @@ from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import Ch
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
+try:
+    import lightllm
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
+
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 1
 BATCH_SIZE = 8
@@ -61,7 +67,7 @@ def check_chatglm2(rank, world_size, port):
     run_chatglm2_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
+try:
+    import lightllm
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
+
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 2
 BATCH_SIZE = 8
@@ -57,7 +63,7 @@ def check_llama(rank, world_size, port):
     run_llama_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
+try:
+    import lightllm
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    HAS_LIGHTLLM_KERNEL = False
+
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 2
 BATCH_SIZE = 8
@@ -55,7 +61,7 @@ def check_llama(rank, world_size, port):
     run_llama_test()
 
 
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -1,60 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-import pytest
-import torch
-from torch import nn
-
-try:
-    from vllm import layernorm_ops
-
-    rms_norm = layernorm_ops.rms_norm
-    HAS_VLLM_KERNERL = True
-except:
-    print("please install vllm kernels to install rmsnorm")
-    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
-    HAS_VLLM_KERNERL = False
-
-
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
-def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
-    x = hidden_states
-    out = torch.empty_like(x)
-    rms_norm(
-        out,
-        x,
-        weight,
-        variance_epsilon,
-    )
-    return out
-
-
-@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
-def test_rmsnorm():
-    data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
-    hg_rms = LlamaRMSNorm(64)
-    hg_rms = hg_rms.half().cuda()
-    out_torch = hg_rms(data)
-    out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
-
-    check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
-    assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward"
-
-
-if __name__ == "__main__":
-    test_rmsnorm()
@@ -1,153 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-from typing import Tuple
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
-
-try:
-    from vllm import pos_encoding_ops
-
-    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
-    HAS_VLLM_KERNERL = True
-except:
-    print("fall back to original rotary_embedding_neox of huggingface")
-    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
-    HAS_VLLM_KERNERL = False
-
-
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbeddingNeox(nn.Module):
-    """Reference implementation of the GPT-NeoX style rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        max_position_embeddings: int = 2048,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
-
-
-def run_rotary_embedding_neox(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    max_position: int,
-    rotary_dim: int,
-    dtype: torch.dtype,
-    base: int = 10000,
-) -> None:
-    positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
-    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
-    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    rotary_embedding_neox(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbeddingNeox(
-        dim=rotary_dim,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
-
-    # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
-def test_rotary_embedding():
-    run_rotary_embedding_neox(
-        num_tokens=1024,
-        num_heads=8,
-        head_size=64,
-        max_position=8192,
-        rotary_dim=64,
-        dtype=torch.float16,
-    )
-
-
-if __name__ == "__main__":
-    test_rotary_embedding()