[Inference] Adapt to baichuan2 13B (#5614)

* adapt to baichuan2 13B

* adapt to baichuan2 13B

* change BAICHUAN_MODEL_NAME_OR_PATH

* fix test_decoding_attn.py

* Modifications based on review comments.

* change BAICHUAN_MODEL_NAME_OR_PATH

* mv attn mask processes to test flash decoding

* mv get_alibi_slopes baichuan modeling

* fix bugs in test_baichuan.py
yuehuayingxueluo 2024-04-25 23:11:30 +08:00 committed by GitHub
parent f342a93871
commit 3c91e3f176
10 changed files with 786 additions and 134 deletions


@ -60,4 +60,5 @@ class FDIntermTensors(metaclass=SingletonMeta):
self._mid_output_lse = torch.empty(
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
)
self._tensors_initialized = True
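For reference, the flag added above enables an initialize-once pattern for these singleton-held buffers. A minimal generic sketch of that pattern (hypothetical class, not the actual FDIntermTensors API):
import torch

class LazyDecodingBuffers:
    """Illustration only: allocate intermediate buffers once and guard against re-allocation."""

    _initialized = False

    def initialize(self, max_batch_size, num_attn_heads, kv_max_split_num, head_dim, dtype, device):
        if self._initialized:  # skip repeated allocation on later calls
            return
        self.mid_output = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self._initialized = True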


@ -64,8 +64,15 @@ class KVCacheManager:
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
if hasattr(model_config, "num_key_value_heads"):
self.kv_head_num = getattr(model_config, "num_key_value_heads")
elif hasattr(model_config, "attribute_map") and hasattr(model_config, model_config.attribute_map["num_key_value_heads"]):
self.kv_head_num = getattr(model_config, model_config.attribute_map["num_key_value_heads"])
else:
self.kv_head_num = self.head_num
assert (
self.kv_head_num % self.tp_size == 0
), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"


@ -1,19 +1,83 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
rms_layernorm,
rotary_embedding,
)
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning("flash_attn2 has not been installed yet; the Triton flash attention kernel will be used instead.")
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
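As a quick sanity check of the formula above (illustrative only, relying on this module's imports): for num_heads=8 the closest power of two is 8, the base is 2**-1, and the slopes form the geometric sequence 1/2, 1/4, ..., 1/256.
slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
assert torch.allclose(slopes, 0.5 ** torch.arange(1, 9, dtype=torch.float32))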
def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,
norm_output: torch.Tensor,
residual: torch.Tensor = None,
use_cuda_kernel: bool = True,
):
# Baichuan2 7B and 13B use different attribute names for the RMSNorm epsilon.
if hasattr(self, "variance_epsilon"):
eps = self.variance_epsilon
elif hasattr(self, "epsilon"):
eps = self.epsilon
else:
raise TypeError(
"Currently, the epsilon attribute of Baichuan 7B/13B RMSNorm should be named 'variance_epsilon' or 'epsilon'."
)
if use_cuda_kernel:
if residual is not None:
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
return hidden_states, residual
if norm_output is None:
norm_output = torch.empty_like(hidden_states)
inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
return norm_output, hidden_states
else:
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
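For reference, both branches above are expected to match a plain PyTorch RMSNorm (with the fused variant adding the residual in place before normalizing). A minimal reference implementation, as a sketch of the usual RMSNorm definition rather than the kernels' actual code:
import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * w, accumulated in fp32 for stability
    x = hidden_states.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight * x).to(hidden_states.dtype)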
class NopadBaichuanAttention(nn.Module):
def __init__(
self,
@ -39,9 +103,11 @@ class NopadBaichuanAttention(nn.Module):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
# Used to adapt llama_base_attn_forward
self.num_key_value_heads = self.num_heads
self.alibi_slopes = None
self.use_alibi_attn = False
if self.hidden_size == 5120:  # Baichuan2-13B (hidden_size 5120) uses ALiBi attention instead of RoPE
self.use_alibi_attn = True
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@ -112,26 +178,124 @@ class NopadBaichuanAttention(nn.Module):
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
return NopadLlamaAttention.forward(
self,
hidden_states=hidden_states,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
token_nums = hidden_states.size(0)
# fused qkv
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)
block_size = k_cache.size(-2)
if is_prompts:
if (
not is_verifier
and use_cuda_kernel
and query_states.dtype != torch.float32
and use_flash_attn2
and not self.use_alibi_attn
):
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel:
if not self.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
else:
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
return attn_output
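The fused QKV projection at the top of this forward works by stacking the three projection weights and batching the matmul against a broadcast copy of the hidden states. A small equivalence sketch with made-up shapes (relies on this module's torch import; not part of this commit):
hidden = torch.randn(7, 128)                                 # [num_tokens, hidden_size]
wq, wk, wv = (torch.randn(128, 128) for _ in range(3))       # [hidden_size, num_heads * head_dim] each
qkv_weight = torch.stack([wq, wk, wv], dim=0)                # [3, hidden_size, num_heads * head_dim]
q, k, v = torch.bmm(hidden.expand(3, -1, -1), qkv_weight).unbind(0)
assert torch.allclose(q, hidden @ wq, atol=1e-5) and torch.allclose(v, hidden @ wv, atol=1e-5)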
# NOTE This will cause differences as the output length increases.
class NopadBaichuanMLP(nn.Module):


@ -1,12 +1,15 @@
import torch.nn as nn
from torch.nn import Parameter
from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_baichuan import (
NopadBaichuanAttention,
NopadBaichuanMLP,
baichuan_rmsnorm_forward,
)
from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@ -21,26 +24,30 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
policy = super().module_policy()
decoder_attribute_replacement = {
"lm_head.weight": Parameter(
nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
),
"lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
policy["DecoderLayer"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
)
# used for replacing the Baichuan 7B/13B decoder layers
for layer_name in ["DecoderLayer", "BaichuanLayer"]:
policy[layer_name] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
)
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
@ -48,11 +55,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)
return policy


@ -185,6 +185,192 @@ def _fwd_context_paged_attention_kernel(
return
# Triton 2.1.0
@triton.jit
def _alibi_fwd_context_paged_attention_kernel(
Q,
K,
V,
O,
KCache,
VCache,
BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
batch_size,
alibi_slopes,
stride_qt,
stride_qh,
stride_qd,
stride_kt,
stride_kh,
stride_kd,
stride_vt,
stride_vh,
stride_vd,
stride_ot,
stride_oh,
stride_od,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
context_lengths,
sm_scale,
KV_GROUPS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
cur_kv_head_idx = cur_head_idx // KV_GROUPS
global_block_start_offset = block_start_m * BLOCK_M
# NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same
tl.static_assert(BLOCK_M == BLOCK_N)
tl.static_assert(BLOCK_N == BLOCK_SIZE)
# get the current sequence length from provided context lengths tensor
cur_seq_len = tl.load(context_lengths + cur_seq_idx)
# NOTE When using fused QKV with nopadding context attention,
# we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum`
# could be considered as the start index of the current sequence.
# FIXME might want to explore a better way to compute the sum of previous sequence lengths.
# `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton.
prev_seq_len_sum = 0
for i in range(0, cur_seq_idx):
prev_seq_len_sum += tl.load(context_lengths + i)
offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
Q_block_ptr = tl.make_block_ptr(
base=Q + offset_q,
shape=(cur_seq_len, HEAD_DIM),
strides=(stride_qt, stride_qd),
offsets=(global_block_start_offset, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + offset_kv,
shape=(HEAD_DIM, cur_seq_len),
strides=(stride_kd, stride_kt),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + offset_kv,
shape=(cur_seq_len, HEAD_DIM),
strides=(stride_vt, stride_vd),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=(1, 0),
)
O_block_ptr = tl.make_block_ptr(
base=O + offset_q,
shape=(cur_seq_len, HEAD_DIM),
strides=(stride_ot, stride_od),
offsets=(global_block_start_offset, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# block table for the current sequence
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
# block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq)
# Consider `block_start_m` as the logical block idx in the current block table,
# as we have BLOCK_M the same size as the block size.
cur_block_table_idx = block_start_m
cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
offsets_m = global_block_start_offset + tl.arange(0, BLOCK_M)
offsets_n = tl.arange(0, BLOCK_N)
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load alibi_slope
alibi_slope = tl.load(alibi_slopes + cur_head_idx)
m_alibi_offset = tl.arange(0, BLOCK_M)[:, None] + global_block_start_offset
n_alibi_offset = tl.arange(0, BLOCK_N)[None, :]
if global_block_start_offset >= cur_seq_len:
return
Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0))
for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N):
block_start_n = tl.multiple_of(block_start_n, BLOCK_N)
k = tl.load(K_block_ptr, boundary_check=(0, 1))
S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
S_ij += tl.dot(Q_i, k)
S_ij *= sm_scale
S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf"))
alibi = (n_alibi_offset + block_start_n - m_alibi_offset) * alibi_slope
alibi = tl.where((alibi <= 0) & (m_alibi_offset < cur_seq_len), alibi, float("-inf"))
S_ij += alibi
m_ij = tl.max(S_ij, 1) # rowmax(Sij)
m_ij = tl.maximum(m_i, m_ij) # m_ij
S_ij -= m_ij[:, None]
p_ij_hat = tl.exp(S_ij)
scale = tl.exp(m_i - m_ij)
l_ij = scale * l_i + tl.sum(p_ij_hat, 1)
acc = acc * scale[:, None]
v = tl.load(V_block_ptr, boundary_check=(1, 0))
p_ij_hat = p_ij_hat.to(v.type.element_ty)
acc += tl.dot(p_ij_hat, v)
l_i = l_ij
m_i = m_ij
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0))
if cur_head_idx % KV_GROUPS == 0:
# Copy k to corresponding cache block
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kt = global_block_start_offset + tl.arange(0, BLOCK_M)
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
offsets_kcache = (
KCache
+ offset_kvcache
+ offsets_dmodel[None, :] * stride_cached
+ offsets_kcachebs[:, None] * stride_cachebs
)
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
# Copy v to corresponding cache block
offsets_vd = offsets_dmodel
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
offsets_vcachebs = offsets_kcachebs  # same block-size range as the key cache offsets
offsets_vcache = (
VCache
+ offset_kvcache
+ offsets_vcachebs[None, :] * stride_cachebs
+ offsets_dmodel[:, None] * stride_cached
)
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
return
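The only difference from _fwd_context_paged_attention_kernel above is the additive ALiBi bias: for query position m and key position n (with n <= m) the kernel adds alibi_slope * (n - m), a non-positive penalty that grows with distance. A tiny PyTorch reference of that bias for a single head (illustrative sketch, mirroring the test helper further below):
import torch

def alibi_bias_reference(slope: float, seq_len: int) -> torch.Tensor:
    # bias[m, n] = slope * (n - m) for n <= m, and -inf above the causal diagonal
    pos = torch.arange(seq_len)
    bias = slope * (pos[None, :] - pos[:, None]).float()
    return bias.masked_fill(pos[None, :] > pos[:, None], float("-inf"))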
def context_attention_unpadded(
q: torch.Tensor, # [num_tokens, num_heads, head_dim]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim]
@ -195,6 +381,7 @@ def context_attention_unpadded(
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
block_size: int,
output: torch.Tensor = None, # [num_tokens, num_heads, head_dim]
alibi_slopes: torch.Tensor = None, # [num_heads]
max_seq_len: int = None,
sm_scale: int = None,
):
@ -226,40 +413,78 @@ def context_attention_unpadded(
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
_fwd_context_paged_attention_kernel[grid](
q,
k,
v,
output,
k_cache,
v_cache,
block_tables,
num_seqs,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
output.stride(0),
head_dim,
1,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
context_lengths,
sm_scale,
num_kv_group,
block_size,
HEAD_DIM=Lk,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
if alibi_slopes is not None:
_alibi_fwd_context_paged_attention_kernel[grid](
q,
k,
v,
output,
k_cache,
v_cache,
block_tables,
num_seqs,
alibi_slopes,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
output.stride(0),
head_dim,
1,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
context_lengths,
sm_scale,
num_kv_group,
block_size,
HEAD_DIM=Lk,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
else:
_fwd_context_paged_attention_kernel[grid](
q,
k,
v,
output,
k_cache,
v_cache,
block_tables,
num_seqs,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
output.stride(0),
head_dim,
1,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
context_lengths,
sm_scale,
num_kv_group,
block_size,
HEAD_DIM=Lk,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
return output
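A hedged usage sketch of the dispatching wrapper above, mirroring the call in the updated unit test; the tensors are assumed to be pre-allocated with the shapes given in the signature comments:
out = context_attention_unpadded(
    q_unpad,                     # [num_tokens, num_heads, head_dim], num_tokens = sum(context_lengths)
    k_unpad,                     # [num_tokens, num_kv_heads, head_dim]
    v_unpad,                     # [num_tokens, num_kv_heads, head_dim]
    k_cache,                     # blocked K cache, filled as a side effect
    v_cache,                     # blocked V cache, filled as a side effect
    context_lengths,             # [num_seqs]
    block_tables,                # [num_seqs, max_blocks_per_sequence]
    block_size,
    alibi_slopes=alibi_slopes,   # pass None to select the original non-ALiBi kernel
)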


@ -124,6 +124,129 @@ def _flash_decoding_fwd_kernel(
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
# Triton 2.1.0
@triton.jit
def _alibi_flash_decoding_fwd_kernel(
Q, # [batch_size * q_len, head_num, head_dim]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
kv_seq_len, # [batch_size]
q_len,
batch_size,
alibi_slopes,
stride_qt,
stride_qh,
stride_qd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
stride_mid_ot,
stride_mid_oh,
stride_mid_ob,
stride_mid_od,
stride_mid_o_lset,
stride_mid_o_lseh,
stride_mid_o_lseb,
sm_scale,
KV_GROUPS: tl.constexpr,
BLOCK_KV: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_token_off = (cur_token_idx % q_len) - q_len + 1
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
# and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length
# cur_token_off is used as a "mask" here for spec-dec during verification process
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
cur_occupied_size = tl.where(
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
)
tl.device_assert(cur_occupied_size >= 0)
cur_kv_head_idx = cur_head_idx // KV_GROUPS
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
v_cur_block = tl.load(V_block_ptr)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
# use block size of the paged/blocked kv cache
S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
alibi_slope = tl.load(alibi_slopes + cur_head_idx)
position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE)
# NOTE A trick to work around Triton's requirement that both matmul input shapes have dimensions >= 16:
# multiplying two tensors with shapes [1, d] and [d, block_size] would fail, so we broadcast-multiply and reduce instead.
# Refer to https://github.com/openai/triton/discussions/895
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset)
S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf"))
m = tl.max(S_ij, 0)
S_ij -= m
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l
offsets_mid_o = (
cur_token_idx * stride_mid_ot
+ cur_head_idx * stride_mid_oh
+ block_start_kv * stride_mid_ob
+ offsets_dmodel * stride_mid_od
)
tl.store(mid_o + offsets_mid_o, acc)
offsets_mid_o_lse = (
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
)
# logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
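Each program instance above writes one KV block's partial output and its log-sum-exp; the reduce kernel that follows combines them with the standard split-softmax identity (stated here for reference, it is not new code in this commit):
L_j = m_j + log(l_j)            # per-block log-sum-exp, exactly what is stored to mid_o_lse above
L   = log(sum_j exp(L_j))       # global log-sum-exp over the KV splits
O   = sum_j exp(L_j - L) * O_j  # weighted combination of the partial outputs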
# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_reduce_kernel(
@ -197,9 +320,10 @@ def flash_decoding_attention(
output: torch.Tensor = None,
mid_output: torch.Tensor = None,
mid_output_lse: torch.Tensor = None,
alibi_slopes: torch.Tensor = None,
sm_scale: int = None,
kv_group_num: int = 1,
q_len: int = 1,
q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment.
):
"""
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@ -220,6 +344,7 @@ def flash_decoding_attention(
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
q_len > 1 only for verification process in speculative-decoding.
alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding.
block_size (int): Size of each block in the blocked key/value cache.
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
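A hedged call sketch showing where the new alibi_slopes argument fits, mirroring the updated decoding test below; parameter names follow this docstring, the buffers are assumed to be pre-allocated (e.g. via FDIntermTensors), and the exact positional ordering is an assumption:
out = flash_decoding_attention(
    q,                               # [bsz * q_len, num_heads, head_dim]
    k_cache,                         # [num_blocks, num_kv_heads, block_size, head_dim]
    v_cache,                         # [num_blocks, num_kv_heads, block_size, head_dim]
    kv_seq_len,                      # [bsz]
    block_tables,                    # [bsz, max_blocks_per_sequence]
    block_size,
    max_seq_len_in_batch=max_seq_len_in_batch,
    output=output,
    mid_output=mid_output,
    mid_output_lse=mid_output_lse,
    alibi_slopes=alibi_slopes,       # None selects the original non-ALiBi kernel
    sm_scale=sm_scale,
    kv_group_num=kv_group_num,
    q_len=1,                         # the ALiBi path currently assumes q_len == 1
)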
@ -280,38 +405,74 @@ def flash_decoding_attention(
num_heads,
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
)
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
if alibi_slopes is not None:
_alibi_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
alibi_slopes,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
else:
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
grid = (triton.next_power_of_2(bsz * q_len), num_heads)
_flash_decoding_fwd_reduce_kernel[grid](


@ -12,7 +12,8 @@ from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"
def setup_seed(seed):
@ -22,12 +23,10 @@ def setup_seed(seed):
random.seed(seed)
def check_inference_engine(use_engine=False, prompt_template=None):
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
model = model.eval()
inputs = [
@ -35,17 +34,24 @@ def check_inference_engine(use_engine=False, prompt_template=None):
]
output_len = 38
do_sample = False
do_sample = do_sample
if do_sample:
top_p = 0.5
top_k = 50
else:
top_p = None
top_k = None
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample)
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
@ -57,6 +63,8 @@ def check_inference_engine(use_engine=False, prompt_template=None):
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
@ -67,9 +75,15 @@ def check_inference_engine(use_engine=False, prompt_template=None):
@parameterize("prompt_template", [None, "baichuan"])
def check_output_consistency(prompt_template):
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
cai_outputs = check_inference_engine(
use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
transformer_outputs = check_inference_engine(
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
)
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"


@ -64,10 +64,6 @@ def torch_attn_ref(
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
)
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)


@ -2,6 +2,7 @@ import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
@ -19,8 +20,31 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32
def _fill_with_neg_inf(t):
return t.float().fill_(float("-inf")).type_as(t)
# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py
def generate_alibi_mask(slopes, num_heads, max_seq_len, device):
token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1
token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1)
diag = torch.diag(token_position[0])
token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position
alibi = alibi.view(num_heads, 1, max_seq_len)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
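Note that this helper biases scores with slope * n (absolute key position), while the Triton kernels add slope * (n - m) (relative distance). The two differ only by a per-row constant, which softmax cancels, so the reference and the kernels still agree:
slope * n = slope * (n - m) + slope * m    # constant offset per query row m
softmax(s + c) = softmax(s)                # adding a constant to a row leaves softmax unchanged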
def torch_attn_unpad(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
slopes: torch.Tensor = None,
):
# Process sequence one by one and concatenate them together.
# q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
@ -35,6 +59,10 @@ def torch_attn_unpad(
mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
mask[mask == 0.0] = float("-inf")
if slopes is not None:
alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device)
mask = mask + alibi_mask
torch_attn_ref_out = torch_attn_ref(
q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
@ -60,6 +88,7 @@ def torch_attn_unpad(
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_context_attention(
bsz: int,
block_size: int,
@ -67,6 +96,7 @@ def test_context_attention(
num_attn_heads: int,
kv_group_num: int,
same_context_len: bool,
use_alibi_slopes: bool,
):
torch.manual_seed(123)
# It's necessary to clear cache here.
@ -79,6 +109,10 @@ def test_context_attention(
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
alibi_slopes = None
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
@ -100,12 +134,19 @@ def test_context_attention(
_, num_heads, head_dim = q_unpad.shape
out_triton = context_attention_unpadded(
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
q_unpad,
k_unpad,
v_unpad,
k_cache_triton,
v_cache_triton,
context_lengths,
block_tables,
block_size,
alibi_slopes=alibi_slopes,
)
out_triton = out_triton.view(-1, num_heads, head_dim)
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes)
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3)
@ -114,4 +155,4 @@ def test_context_attention(
if __name__ == "__main__":
test_context_attention(4, 32, 8, 16, 1, True)
test_context_attention(4, 32, 8, 16, 1, True, True)


@ -1,7 +1,9 @@
import numpy as np
import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
@ -10,6 +12,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
try:
import triton # noqa
@ -24,6 +27,13 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 128
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
def prepare_data(
bsz: int,
num_attn_heads: int,
@ -64,6 +74,7 @@ def prepare_data(
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("q_len", [1, 5])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding(
bsz: int,
block_size: int,
@ -72,6 +83,7 @@ def test_flash_decoding(
kv_group_num: int,
same_context_len: bool,
q_len: int,
use_alibi_slopes: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
@ -83,6 +95,14 @@ def test_flash_decoding(
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
# Currently, alibi flash decoding does not support q_len>1.
q_len = 1
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_lengths = prepare_data(
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
)
@ -92,6 +112,17 @@ def test_flash_decoding(
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device)
attention_mask = attention_mask + alibi_mask
if q_len == 1:
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
out_torch = torch_attn_ref(
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
@ -130,14 +161,21 @@ def test_flash_decoding(
output,
mid_output,
mid_output_lse,
alibi_slopes=alibi_slopes,
sm_scale=sm_scale,
kv_group_num=kv_group_num,
q_len=q_len,
) # [bsz * q_len, num_heads, head_dim]
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
rtol = 1e-4
# For larger shapes, some output elements are close to zero, which inflates the relative error; relax rtol in that case.
if bsz == 32 and use_alibi_slopes:
rtol = 100
numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
if __name__ == "__main__":
test_flash_decoding(16, 32, 32, 16, 1, True)
test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)