mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Adapt to baichuan2 13B (#5614)
* adapt to baichuan2 13B
* change BAICHUAN_MODEL_NAME_OR_PATH
* fix test_decoding_attn.py
* Modifications based on review comments.
* change BAICHUAN_MODEL_NAME_OR_PATH
* mv attn mask processes to test flash decoding
* mv get_alibi_slopes to baichuan modeling
* fix bugs in test_baichuan.py
parent f342a93871
commit 3c91e3f176
@@ -60,4 +60,5 @@ class FDIntermTensors(metaclass=SingletonMeta):
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )

        self._tensors_initialized = True
@@ -64,8 +64,15 @@ class KVCacheManager:
        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num

        if hasattr(config, "num_key_value_heads"):
            self.kv_head_num = getattr(config, "num_key_value_heads")
        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
        else:
            self.kv_head_num = self.head_num

        assert (
            self.kv_head_num % self.tp_size == 0
        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
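For a rough sense of scale, the attributes read above determine the per-token KV cache footprint. A minimal sketch (editor's illustration; the helper name and the example numbers are not part of the patch):

def kv_cache_bytes_per_token(num_layers: int, kv_head_num: int, head_size: int, elem_size_in_bytes: int) -> int:
    # One K block and one V block are stored per layer, each [kv_head_num, head_size] per token.
    return 2 * num_layers * kv_head_num * head_size * elem_size_in_bytes

# e.g. a Baichuan2-13B-like config (40 layers, 40 KV heads, head size 128, fp16):
# 2 * 40 * 40 * 128 * 2 = 819,200 bytes (~0.8 MB) of KV cache per generated token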
@@ -1,19 +1,83 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
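# Editor's note (not part of the patch): for a head count that is a power of two, the slopes above
# form a geometric sequence, e.g. get_alibi_slopes(8, torch.device("cpu")) is expected to return
#   [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256]
# i.e. slope_i = 2 ** (-8 * (i + 1) / num_heads); heads with smaller slopes apply a weaker distance penalty.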


def baichuan_rmsnorm_forward(
    self,
    hidden_states: torch.Tensor,
    norm_output: torch.Tensor,
    residual: torch.Tensor = None,
    use_cuda_kernel: bool = True,
):
    # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
    if hasattr(self, "variance_epsilon"):
        eps = self.variance_epsilon
    elif hasattr(self, "epsilon"):
        eps = self.epsilon
    else:
        raise TypeError(
            "Currently, the variable name for the epsilon of baichuan 7B/13B should be 'variance_epsilon' or 'epsilon'."
        )

    if use_cuda_kernel:
        if residual is not None:
            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
            return hidden_states, residual

        if norm_output is None:
            norm_output = torch.empty_like(hidden_states)
        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
        return norm_output, hidden_states
    else:
        return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
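For reference, a minimal unfused PyTorch sketch of what the RMSNorm kernels above are assumed to compute (`rmsnorm_ref` is an editor's illustration, not part of the patch):

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float, residual: torch.Tensor = None):
    # Optionally fold the residual in first, mirroring fused_add_rms_layernorm.
    if residual is not None:
        x = x + residual
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    # Return the normalized output and the pre-norm hidden states, like the forward above.
    return normed, x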


class NopadBaichuanAttention(nn.Module):
    def __init__(
        self,
@@ -39,9 +103,11 @@ class NopadBaichuanAttention:
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        # Used to adapt llama_base_attn_forward
        self.num_key_value_heads = self.num_heads
        self.alibi_slopes = None
        self.use_alibi_attn = False
        # Baichuan2-13B (hidden size 5120) uses ALiBi instead of RoPE.
        if self.hidden_size == 5120:
            self.use_alibi_attn = True
            self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)

        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@@ -112,26 +178,124 @@ class NopadBaichuanAttention:
            high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        """

        return NopadLlamaAttention.forward(
            self,
            hidden_states=hidden_states,
            block_tables=block_tables,
            k_cache=k_cache,
            v_cache=v_cache,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=fd_inter_tensor,
            is_prompts=is_prompts,
            is_verifier=is_verifier,
            tokens_to_verify=tokens_to_verify,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
            cu_seqlens=cu_seqlens,
            high_precision=high_precision,
        token_nums = hidden_states.size(0)
        # fused qkv
        hidden_states = hidden_states.expand(3, -1, -1)
        query_states, key_states, value_states = (
            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )
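        # Editor's note (not part of the patch): self.qkv_weight stacks the q/k/v projection
        # weights into shape [3, hidden_size, num_heads * head_dim], and expand(3, -1, -1) only
        # creates a broadcast view of hidden_states, so the single bmm above is equivalent to
        # computing hidden_states @ W_q, hidden_states @ W_k, and hidden_states @ W_v separately,
        # without materializing three copies of the activations.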
        block_size = k_cache.size(-2)

        if is_prompts:
            if (
                not is_verifier
                and use_cuda_kernel
                and query_states.dtype != torch.float32
                and use_flash_attn2
                and not self.use_alibi_attn
            ):
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )

                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                if not self.use_alibi_attn:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    alibi_slopes=self.alibi_slopes,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                )
        else:
            q_len = tokens_to_verify + 1 if is_verifier else 1

            if use_cuda_kernel:
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding_and_cache_copy(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        sequence_lengths,
                        block_tables,
                        high_precision,
                    )
                else:
                    inference_ops.decode_kv_cache_memcpy(
                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                    )
            else:
                if not is_verifier and not self.use_alibi_attn:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                else:
                    if not self.use_alibi_attn:
                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )

            attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
                v_cache=v_cache,
                kv_seq_len=sequence_lengths,
                block_tables=block_tables,
                block_size=block_size,
                max_seq_len_in_batch=kv_seq_len,
                output=output_tensor,
                mid_output=fd_inter_tensor.mid_output,
                mid_output_lse=fd_inter_tensor.mid_output_lse,
                alibi_slopes=self.alibi_slopes,
                sm_scale=sm_scale,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output


# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(nn.Module):
@@ -1,12 +1,15 @@
import torch.nn as nn
from torch.nn import Parameter

from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_baichuan import (
    NopadBaichuanAttention,
    NopadBaichuanMLP,
    baichuan_rmsnorm_forward,
)
from colossalai.inference.modeling.models.nopadding_llama import (
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
    llama_model_forward,
    llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -21,26 +24,30 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
        policy = super().module_policy()

        decoder_attribute_replacement = {
            "lm_head.weight": Parameter(
                nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
            ),
            "lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
        }
        policy["BaichuanForCausalLM"] = ModulePolicyDescription(
            attribute_replacement=decoder_attribute_replacement,
        )

        policy["DecoderLayer"] = ModulePolicyDescription(
            sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="mlp",
                    target_module=NopadBaichuanMLP,
                ),
                SubModuleReplacementDescription(
                    suffix="self_attn",
                    target_module=NopadBaichuanAttention,
                ),
            ]
        )
        # used for replacing Baichuan 7B/13B decoder layers
        for layer_name in ["DecoderLayer", "BaichuanLayer"]:
            policy[layer_name] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=NopadBaichuanMLP,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn",
                        target_module=NopadBaichuanAttention,
                    ),
                ]
            )

            self.append_or_create_method_replacement(
                description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
            )

        self.append_or_create_method_replacement(
            description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
@@ -48,11 +55,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
        self.append_or_create_method_replacement(
            description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
        )

        self.append_or_create_method_replacement(
            description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
        )
        self.append_or_create_method_replacement(
            description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
            description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
        )

        return policy
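For orientation, a minimal sketch of how this policy is exercised end to end, mirroring the test further below (the model path and prompt are placeholders, and the `InferenceConfig` import path is assumed since it is not shown in this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig  # import path assumed
from colossalai.inference.core.engine import InferenceEngine

model_path = "baichuan-inc/Baichuan2-13B-Base"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()

inference_config = InferenceConfig(max_output_len=38, use_cuda_kernel=True)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
engine.add_request(prompts=["An example prompt."])  # placeholder prompt
outputs = engine.generate(generation_config=GenerationConfig(do_sample=False))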
@@ -185,6 +185,192 @@ def _fwd_context_paged_attention_kernel(
    return


# Triton 2.1.0
@triton.jit
def _alibi_fwd_context_paged_attention_kernel(
    Q,
    K,
    V,
    O,
    KCache,
    VCache,
    BLOCK_TABLES,  # [num_seqs, max_blocks_per_sequence]
    batch_size,
    alibi_slopes,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ot,
    stride_oh,
    stride_od,
    stride_cacheb,
    stride_cacheh,
    stride_cachebs,
    stride_cached,
    stride_bts,
    stride_btb,
    context_lengths,
    sm_scale,
    KV_GROUPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    if cur_seq_idx >= batch_size:
        return
    cur_head_idx = tl.program_id(1)
    block_start_m = tl.program_id(2)  # Br, max_input_len // Block_M
    cur_kv_head_idx = cur_head_idx // KV_GROUPS

    global_block_start_offset = block_start_m * BLOCK_M

    # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same
    tl.static_assert(BLOCK_M == BLOCK_N)
    tl.static_assert(BLOCK_N == BLOCK_SIZE)

    # get the current sequence length from the provided context lengths tensor
    cur_seq_len = tl.load(context_lengths + cur_seq_idx)
    # NOTE when dealing with fused QKV and a nopadding context attention,
    # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum`
    # could be considered as the start index of the current sequence.
    # FIXME might want to explore a better way to get the summation of prev seq lengths.
    # `tl.sum(tensor[:end])` is invalid as tensor slicing is not supported in triton.
    prev_seq_len_sum = 0
    for i in range(0, cur_seq_idx):
        prev_seq_len_sum += tl.load(context_lengths + i)

    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + offset_q,
        shape=(cur_seq_len, HEAD_DIM),
        strides=(stride_qt, stride_qd),
        offsets=(global_block_start_offset, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + offset_kv,
        shape=(HEAD_DIM, cur_seq_len),
        strides=(stride_kd, stride_kt),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + offset_kv,
        shape=(cur_seq_len, HEAD_DIM),
        strides=(stride_vt, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + offset_q,
        shape=(cur_seq_len, HEAD_DIM),
        strides=(stride_ot, stride_od),
        offsets=(global_block_start_offset, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )

    # block table for the current sequence
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    # block indexes on the block table (i.e. 0, 1, 2, ..., max_blocks_per_seq)
    # Consider `block_start_m` as the logical block idx in the current block table,
    # as we have BLOCK_M the same size as the block size.
    cur_block_table_idx = block_start_m
    cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

    offsets_m = global_block_start_offset + tl.arange(0, BLOCK_M)
    offsets_n = tl.arange(0, BLOCK_N)
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # load alibi_slope
    alibi_slope = tl.load(alibi_slopes + cur_head_idx)
    m_alibi_offset = tl.arange(0, BLOCK_M)[:, None] + global_block_start_offset
    n_alibi_offset = tl.arange(0, BLOCK_N)[None, :]

    if global_block_start_offset >= cur_seq_len:
        return

    Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0))

    for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N):
        block_start_n = tl.multiple_of(block_start_n, BLOCK_N)

        k = tl.load(K_block_ptr, boundary_check=(0, 1))
        S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        S_ij += tl.dot(Q_i, k)
        S_ij *= sm_scale
        S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf"))

        alibi = (n_alibi_offset + block_start_n - m_alibi_offset) * alibi_slope
        alibi = tl.where((alibi <= 0) & (m_alibi_offset < cur_seq_len), alibi, float("-inf"))
        S_ij += alibi

        m_ij = tl.max(S_ij, 1)  # rowmax(Sij)
        m_ij = tl.maximum(m_i, m_ij)  # m_ij
        S_ij -= m_ij[:, None]
        p_ij_hat = tl.exp(S_ij)
        scale = tl.exp(m_i - m_ij)
        l_ij = scale * l_i + tl.sum(p_ij_hat, 1)
        acc = acc * scale[:, None]

        v = tl.load(V_block_ptr, boundary_check=(1, 0))
        p_ij_hat = p_ij_hat.to(v.type.element_ty)

        acc += tl.dot(p_ij_hat, v)
        l_i = l_ij
        m_i = m_ij
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0))

    if cur_head_idx % KV_GROUPS == 0:
        # Copy k to the corresponding cache block
        offsets_dmodel = tl.arange(0, HEAD_DIM)
        offsets_kt = global_block_start_offset + tl.arange(0, BLOCK_M)
        offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
        k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
        offsets_kcache = (
            KCache
            + offset_kvcache
            + offsets_dmodel[None, :] * stride_cached
            + offsets_kcachebs[:, None] * stride_cachebs
        )
        tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
        # Copy v to the corresponding cache block
        offsets_vd = offsets_dmodel
        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
        offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
        v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
        offsets_vcachebs = offsets_kcachebs  # same block size range, just to note it here
        offsets_vcache = (
            VCache
            + offset_kvcache
            + offsets_vcachebs[None, :] * stride_cachebs
            + offsets_dmodel[:, None] * stride_cached
        )
        tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)

    return
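# Editor's reference (not part of the patch): for a single sequence and a single head, ignoring the
# paged KV-cache copy, the kernel above is assumed to compute ALiBi-biased causal attention:
def _alibi_causal_attention_ref(q, k, v, slope, sm_scale):
    # q, k, v: [seq_len, head_dim]; slope: this head's ALiBi slope
    seq_len = q.size(0)
    pos = torch.arange(seq_len, device=q.device)
    scores = (q @ k.transpose(0, 1)) * sm_scale
    bias = slope * (pos[None, :] - pos[:, None]).to(scores.dtype)  # (key_pos - query_pos) * slope, <= 0 for causal positions
    scores = scores + bias.masked_fill(pos[None, :] > pos[:, None], float("-inf"))  # causal mask
    return torch.softmax(scores.float(), dim=-1).to(v.dtype) @ v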


def context_attention_unpadded(
    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
@@ -195,6 +381,7 @@ def context_attention_unpadded(
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
    block_size: int,
    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
    alibi_slopes: torch.Tensor = None,  # [num_heads]
    max_seq_len: int = None,
    sm_scale: int = None,
):
@@ -226,40 +413,78 @@ def context_attention_unpadded(
    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
    grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))

    _fwd_context_paged_attention_kernel[grid](
        q,
        k,
        v,
        output,
        k_cache,
        v_cache,
        block_tables,
        num_seqs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        output.stride(0),
        head_dim,
        1,
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        context_lengths,
        sm_scale,
        num_kv_group,
        block_size,
        HEAD_DIM=Lk,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    if alibi_slopes is not None:
        _alibi_fwd_context_paged_attention_kernel[grid](
            q,
            k,
            v,
            output,
            k_cache,
            v_cache,
            block_tables,
            num_seqs,
            alibi_slopes,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            output.stride(0),
            head_dim,
            1,
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            block_tables.stride(0),
            block_tables.stride(1),
            context_lengths,
            sm_scale,
            num_kv_group,
            block_size,
            HEAD_DIM=Lk,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
    else:
        _fwd_context_paged_attention_kernel[grid](
            q,
            k,
            v,
            output,
            k_cache,
            v_cache,
            block_tables,
            num_seqs,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            output.stride(0),
            head_dim,
            1,
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            block_tables.stride(0),
            block_tables.stride(1),
            context_lengths,
            sm_scale,
            num_kv_group,
            block_size,
            HEAD_DIM=Lk,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )

    return output
@@ -124,6 +124,129 @@ def _flash_decoding_fwd_kernel(
    tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))


# Triton 2.1.0
@triton.jit
def _alibi_flash_decoding_fwd_kernel(
    Q,  # [batch_size * q_len, head_num, head_dim]
    KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
    block_tables,  # [batch_size, max_blocks_per_sequence]
    mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
    mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
    kv_seq_len,  # [batch_size]
    q_len,
    batch_size,
    alibi_slopes,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_cacheb,
    stride_cacheh,
    stride_cachebs,
    stride_cached,
    stride_bts,
    stride_btb,
    stride_mid_ot,
    stride_mid_oh,
    stride_mid_ob,
    stride_mid_od,
    stride_mid_o_lset,
    stride_mid_o_lseh,
    stride_mid_o_lseb,
    sm_scale,
    KV_GROUPS: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    cur_token_idx = tl.program_id(0)
    cur_seq_idx = cur_token_idx // q_len
    if cur_seq_idx >= batch_size:
        return
    cur_token_off = (cur_token_idx % q_len) - q_len + 1
    cur_head_idx = tl.program_id(1)
    block_start_kv = tl.program_id(2)  # for splitting k/v

    # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
    # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as a multiple of BLOCK_SIZE)
    # and then support calculating multiple kv cache blocks on an instance
    tl.static_assert(BLOCK_KV == BLOCK_SIZE)
    # get the current (kv) sequence length
    # cur_token_off is used as a "mask" here for spec-dec during the verification process
    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
        return

    offsets_dmodel = tl.arange(0, HEAD_DIM)
    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
    q = tl.load(Q + offsets_q)
    # block table for the current sequence
    block_table_ptr = block_tables + cur_seq_idx * stride_bts
    # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
    # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
    cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
    cur_occupied_size = tl.where(
        (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
    )
    tl.device_assert(cur_occupied_size >= 0)

    cur_kv_head_idx = cur_head_idx // KV_GROUPS
    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
    K_block_ptr = tl.make_block_ptr(
        base=KCache + offset_kvcache,
        shape=(cur_occupied_size, HEAD_DIM),
        strides=(stride_cachebs, stride_cached),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE, HEAD_DIM),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=VCache + offset_kvcache,
        shape=(cur_occupied_size, HEAD_DIM),
        strides=(stride_cachebs, stride_cached),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE, HEAD_DIM),
        order=(0, 1),
    )
    k_cur_block = tl.load(K_block_ptr)
    v_cur_block = tl.load(V_block_ptr)
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
    # use the block size of the paged/blocked kv cache
    S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    alibi_slope = tl.load(alibi_slopes + cur_head_idx)
    position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE)

    # NOTE a trick to get around triton's requirement that values in both the first and second input shapes must be >= 16,
    # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
    # Refer to https://github.com/openai/triton/discussions/895
    S_ij += tl.sum(q[None, :] * k_cur_block, 1)
    S_ij *= sm_scale
    S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset)
    S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf"))

    m = tl.max(S_ij, 0)
    S_ij -= m
    p_ij_hat = tl.exp(S_ij)
    l = tl.sum(p_ij_hat, 0)
    p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
    acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
    acc = acc / l

    offsets_mid_o = (
        cur_token_idx * stride_mid_ot
        + cur_head_idx * stride_mid_oh
        + block_start_kv * stride_mid_ob
        + offsets_dmodel * stride_mid_od
    )
    tl.store(mid_o + offsets_mid_o, acc)
    offsets_mid_o_lse = (
        cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
    )
    # logsumexp L^(j) = m^(j) + log(l^(j))
    tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
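# Editor's note (not part of the patch): in the decode path the single query sits at position
# cur_kv_seq_len - 1, so the ALiBi term above reduces to a per-key distance penalty:
#   S_j = (q . k_j) * sm_scale - alibi_slope * (cur_kv_seq_len - 1 - j)
# which is the prefill kernel's bias slope * (key_pos - query_pos) evaluated at the last query position.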


# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_reduce_kernel(
@@ -197,9 +320,10 @@ def flash_decoding_attention(
    output: torch.Tensor = None,
    mid_output: torch.Tensor = None,
    mid_output_lse: torch.Tensor = None,
    alibi_slopes: torch.Tensor = None,
    sm_scale: int = None,
    kv_group_num: int = 1,
    q_len: int = 1,
    q_len: int = 1,  # NOTE alibi flash decoding does not support q_len > 1 at this moment.
):
    """
    Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -220,6 +344,7 @@ def flash_decoding_attention(
        mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
            Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
            q_len > 1 only for verification process in speculative-decoding.
        alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding.
        block_size (int): Size of each block in the blocked key/value cache.
        num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
        q_length (int): Query length. Used for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
@@ -280,38 +405,74 @@ def flash_decoding_attention(
        num_heads,
        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
    )
    _flash_decoding_fwd_kernel[grid](
        q,
        k_cache,
        v_cache,
        block_tables,
        mid_output,
        mid_output_lse,
        kv_seq_len,
        q_len,
        bsz,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        mid_output.stride(0),
        mid_output.stride(1),
        mid_output.stride(2),
        mid_output.stride(3),
        mid_output_lse.stride(0),
        mid_output_lse.stride(1),
        mid_output_lse.stride(2),
        sm_scale,
        KV_GROUPS=kv_group_num,
        BLOCK_KV=block_size,
        BLOCK_SIZE=block_size,
        HEAD_DIM=head_dim,
    )

    if alibi_slopes is not None:
        _alibi_flash_decoding_fwd_kernel[grid](
            q,
            k_cache,
            v_cache,
            block_tables,
            mid_output,
            mid_output_lse,
            kv_seq_len,
            q_len,
            bsz,
            alibi_slopes,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            block_tables.stride(0),
            block_tables.stride(1),
            mid_output.stride(0),
            mid_output.stride(1),
            mid_output.stride(2),
            mid_output.stride(3),
            mid_output_lse.stride(0),
            mid_output_lse.stride(1),
            mid_output_lse.stride(2),
            sm_scale,
            KV_GROUPS=kv_group_num,
            BLOCK_KV=block_size,
            BLOCK_SIZE=block_size,
            HEAD_DIM=head_dim,
        )
    else:
        _flash_decoding_fwd_kernel[grid](
            q,
            k_cache,
            v_cache,
            block_tables,
            mid_output,
            mid_output_lse,
            kv_seq_len,
            q_len,
            bsz,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            block_tables.stride(0),
            block_tables.stride(1),
            mid_output.stride(0),
            mid_output.stride(1),
            mid_output.stride(2),
            mid_output.stride(3),
            mid_output_lse.stride(0),
            mid_output_lse.stride(1),
            mid_output_lse.stride(2),
            sm_scale,
            KV_GROUPS=kv_group_num,
            BLOCK_KV=block_size,
            BLOCK_SIZE=block_size,
            HEAD_DIM=head_dim,
        )

    grid = (triton.next_power_of_2(bsz * q_len), num_heads)
    _flash_decoding_fwd_reduce_kernel[grid](
@@ -12,7 +12,8 @@ from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base"


def setup_seed(seed):
@@ -22,12 +23,10 @@ def setup_seed(seed):
    random.seed(seed)


def check_inference_engine(use_engine=False, prompt_template=None):
def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
    ).cuda()
    model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda()
    model = model.eval()

    inputs = [
@@ -35,17 +34,24 @@ def check_inference_engine(use_engine=False, prompt_template=None):
    ]

    output_len = 38
    do_sample = False
    do_sample = do_sample

    if do_sample:
        top_p = 0.5
        top_k = 50
    else:
        top_p = None
        top_k = None

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
            max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample)
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
@@ -57,6 +63,8 @@ def check_inference_engine(use_engine=False, prompt_template=None):
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
@@ -67,9 +75,15 @@ def check_inference_engine(use_engine=False, prompt_template=None):


@parameterize("prompt_template", [None, "baichuan"])
def check_output_consistency(prompt_template):
    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
@parameterize("do_sample", [True, False])
@parameterize("use_cuda_kernel", [True, False])
def check_output_consistency(prompt_template, do_sample, use_cuda_kernel):
    cai_outputs = check_inference_engine(
        use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
    )
    transformer_outputs = check_inference_engine(
        use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template
    )

    for s1, s2 in zip(cai_outputs, transformer_outputs):
        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
@@ -64,10 +64,6 @@ def torch_attn_ref(

    assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
            )
        attn_scores = attn_scores + attention_mask

    attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
@@ -2,6 +2,7 @@ import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
@@ -19,8 +20,31 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32


def _fill_with_neg_inf(t):
    return t.float().fill_(float("-inf")).type_as(t)


# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py
def generate_alibi_mask(slopes, num_heads, max_seq_len, device):
    token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1
    token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1)
    diag = torch.diag(token_position[0])
    token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position
    alibi = alibi.view(num_heads, 1, max_seq_len)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


def torch_attn_unpad(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    context_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    slopes: torch.Tensor = None,
):
    # Process sequences one by one and concatenate them together.
    # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
@@ -35,6 +59,10 @@ def torch_attn_unpad(
        mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
        mask[mask == 0.0] = float("-inf")

        if slopes is not None:
            alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device)
            mask = mask + alibi_mask

        torch_attn_ref_out = torch_attn_ref(
            q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
            k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
@@ -60,6 +88,7 @@ def torch_attn_unpad(
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_context_attention(
    bsz: int,
    block_size: int,
@@ -67,6 +96,7 @@ def test_context_attention(
    num_attn_heads: int,
    kv_group_num: int,
    same_context_len: bool,
    use_alibi_slopes: bool,
):
    torch.manual_seed(123)
    # It's necessary to clear cache here.
@@ -79,6 +109,10 @@ def test_context_attention(
    max_seq_len = max_num_blocks_per_seq * block_size
    dtype = torch.float16
    device = get_current_device()
    alibi_slopes = None

    if use_alibi_slopes:
        alibi_slopes = get_alibi_slopes(num_attn_heads, device)

    if same_context_len:
        context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
@@ -100,12 +134,19 @@ def test_context_attention(
    _, num_heads, head_dim = q_unpad.shape

    out_triton = context_attention_unpadded(
        q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
        q_unpad,
        k_unpad,
        v_unpad,
        k_cache_triton,
        v_cache_triton,
        context_lengths,
        block_tables,
        block_size,
        alibi_slopes=alibi_slopes,
    )

    out_triton = out_triton.view(-1, num_heads, head_dim)

    out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
    out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes)

    assert out_torch.shape == out_triton.shape
    assert torch.allclose(out_torch, out_triton, atol=1e-3)
@@ -114,4 +155,4 @@ def test_context_attention(


if __name__ == "__main__":
    test_context_attention(4, 32, 8, 16, 1, True)
    test_context_attention(4, 32, 8, 16, 1, True, True)
@@ -1,7 +1,9 @@
import numpy as np
import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
@@ -10,6 +12,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
    generate_caches_and_block_tables_v2,
    torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask

try:
    import triton  # noqa
@@ -24,6 +27,13 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 128


def numpy_allclose(x, y, rtol, atol):
    x_numpy = x.detach().cpu().numpy()
    y_numpy = y.detach().cpu().numpy()

    np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)


def prepare_data(
    bsz: int,
    num_attn_heads: int,
@@ -64,6 +74,7 @@ def prepare_data(
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("q_len", [1, 5])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding(
    bsz: int,
    block_size: int,
@@ -72,6 +83,7 @@ def test_flash_decoding(
    kv_group_num: int,
    same_context_len: bool,
    q_len: int,
    use_alibi_slopes: bool,
):
    torch.manual_seed(123)
    torch.cuda.empty_cache()
@@ -83,6 +95,14 @@ def test_flash_decoding(
    max_seq_len = block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

    if use_alibi_slopes:
        alibi_slopes = get_alibi_slopes(num_attn_heads, device)
        # Currently, alibi flash decoding does not support q_len > 1.
        q_len = 1
    else:
        alibi_slopes = None

    q, k_unpad, v_unpad, kv_lengths = prepare_data(
        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
    )
@@ -92,6 +112,17 @@ def test_flash_decoding(
    k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
    v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
    attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)

    if use_alibi_slopes:
        alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device)
        attention_mask = attention_mask + alibi_mask

        if q_len == 1:
            if len(attention_mask.size()) == 4:
                attention_mask = attention_mask[:, :, -1:, :]
            else:
                attention_mask = attention_mask[:, -1:, :]

    out_torch = torch_attn_ref(
        q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
    )
@@ -130,14 +161,21 @@ def test_flash_decoding(
        output,
        mid_output,
        mid_output_lse,
        alibi_slopes=alibi_slopes,
        sm_scale=sm_scale,
        kv_group_num=kv_group_num,
        q_len=q_len,
    )  # [bsz * q_len, num_heads, head_dim]

    assert out_torch.shape == out_triton.shape
    assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)

    rtol = 1e-4
    # After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
    if bsz == 32 and use_alibi_slopes:
        rtol = 100

    numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)


if __name__ == "__main__":
    test_flash_decoding(16, 32, 32, 16, 1, True)
    test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)