[inference] Optimize the usage of the mid tensors space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
yuehuayingxueluo 2024-01-26 14:00:10 +08:00 committed by GitHub
parent af8359c430
commit 4f28cb43c0
16 changed files with 199 additions and 57 deletions


@ -55,6 +55,7 @@ class InferenceConfig:
def __post_init__(self):
self._init_batch_size()
self._verify_config()
self._get_dtype()
def _init_batch_size(self):
"""
@ -84,6 +85,7 @@ class InferenceConfig:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert self.dtype in [
"fp16",
"fp32",
@ -97,3 +99,11 @@ class InferenceConfig:
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
def _get_dtype(self) -> None:
if self.dtype == "fp32" or self.dtype == torch.float32:
self.dtype = torch.float32
elif self.dtype == "fp16" or self.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
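For reference, here is a minimal standalone sketch of the dtype normalization this hunk adds: string aliases and torch dtypes are both collapsed to a single torch.dtype in __post_init__, so downstream consumers (the engine and the KV cache manager changed below) can read config.dtype directly. The helper name is hypothetical; the real InferenceConfig._get_dtype mutates self.dtype in place.

import torch

def normalize_dtype(dtype):
    # hypothetical helper mirroring InferenceConfig._get_dtype
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    else:
        # anything else ("bf16", torch.bfloat16, ...) falls back to bfloat16
        return torch.bfloat16

assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype(torch.float32) is torch.float32
assert normalize_dtype("bf16") is torch.bfloat16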


@ -51,17 +51,10 @@ class InferenceEngine:
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
model = model.eval()
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
model.to(self.dtype)
if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
@ -217,6 +210,7 @@ class InferenceEngine:
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
@ -241,7 +235,6 @@ class InferenceEngine:
batch,
self.k_cahce,
self.v_cache,
padding_id=self.tokenizer.pad_token_id,
)
logits = logits[:, -1, :]
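With the dtype already normalized by the config, the engine no longer needs per-dtype branches: a single model.to(self.dtype) covers both the old model.half() and model.to(torch.bfloat16) paths. A small sketch of why the behavior is preserved, using a plain nn.Module as a stand-in for the model:

import torch
import torch.nn as nn

config_dtype = torch.float16            # produced by InferenceConfig._get_dtype
model = nn.Linear(8, 8).eval()
model = model.to(config_dtype)          # equivalent to model.half() for fp16
assert next(model.parameters()).dtype is torch.float16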


@ -4,6 +4,7 @@ import torch
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
@ -69,20 +70,60 @@ class RequestHandler:
Args:
inference_config: Configuration for initializing and managing the kv cache.
model_config: Configuration for the model.
dtype (torch.dtype): The data type for weights and activations.
"""
def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self._init_cache(model_config)
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size
# initialize cache
self._init_cache(model_config)
# initialize batch
device = torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads
fd_inter_tensor = FDIntermTensors()
fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
device=device,
)
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs; this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
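The handler now sizes and allocates the flash-decoding intermediate tensors once and hands the same FDIntermTensors instance to both the prefill and decode BatchInfo objects. A hedged sketch of what those shared buffers amount to: FDIntermTensors' internals are not shown in this diff, so the shapes below follow the mid_output / mid_output_lse documentation in flash_decoding.py and the buffers built in the flash-decoding test, and the float32 accumulation dtype is an assumption.

import torch

# Illustrative sizes (assumptions, not values taken from the diff).
max_batch_size = 8
max_input_len, max_output_len, block_size = 256, 128, 16
num_attn_heads, head_dim = 32, 128
device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound on kv splits a single sequence can need (formula from this hunk).
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size

# Allocated once and reused on every decode step instead of being re-created
# inside the kernel wrapper for each call.
mid_output = torch.empty(
    (max_batch_size, num_attn_heads, kv_max_split_num, head_dim),
    dtype=torch.float32, device=device,
)
mid_output_lse = torch.empty(
    (max_batch_size, num_attn_heads, kv_max_split_num),
    dtype=torch.float32, device=device,
)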


@ -58,12 +58,7 @@ class KVCacheManager:
# Parallel settings
self.tp_size = config.tp_size
# Model settings
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA


@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
@ -50,7 +51,6 @@ def llama_causal_lm_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
@ -58,7 +58,6 @@ def llama_causal_lm_forward(
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
padding_id=padding_id,
)
logits = self.lm_head(hidden_states)
return logits
@ -70,11 +69,10 @@ def llama_model_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask(padding_id)
attention_mask = batch.get_attn_mask()
if attention_mask is not None:
if HAS_TRITON:
@ -84,6 +82,7 @@ def llama_model_forward(
else:
sequence_lengths = batch.get_sequence_lengths()
batch_size, _ = input_ids.shape
kv_seq_len = sequence_lengths.max().item()
if attention_mask is not None:
@ -102,7 +101,22 @@ def llama_model_forward(
hidden_states = self.embed_tokens(input_ids)
cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
# When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)
cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@ -116,6 +130,9 @@ def llama_model_forward(
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)
hidden_states = self.norm(hidden_states)
@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: int = None,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)
hidden_states = residual + hidden_states
@ -178,6 +201,9 @@ def llama_attn_forward(
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@ -206,7 +232,17 @@ def llama_attn_forward(
if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
@ -214,7 +250,17 @@ def llama_attn_forward(
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin from the cache and return them in nopad format.
Args:
lengths: shape (num_seqs,), stores the length of each sequence.
cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
is_prompts: bool, marks whether the batch is in prefill mode.
dtype: The data type of this inference process.
"""
if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
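
The core of the change in this file: llama_model_forward now allocates a single attention output buffer per forward pass (token-level layout for prefill, one query token per sequence for decode) and threads it, together with sm_scale and the shared fd_inter_tensor, through every decoder layer, so the Triton wrappers no longer allocate their own outputs. A small sketch of just that allocation logic; the helper name is hypothetical, the real code allocates inline with the shapes shown in the diff.

import torch

def make_attn_output_buffer(batch_size, sequence_lengths, num_heads, head_dim,
                            is_prompts, dtype, device):
    # hypothetical helper; shapes match the output_tensor allocation above
    if is_prompts:
        # prefill: unpadded layout, one row per token across the whole batch
        return torch.zeros(
            (sequence_lengths.sum().item(), num_heads, head_dim),
            dtype=dtype, device=device,
        )
    # decode: a single query token per sequence
    return torch.zeros((batch_size, 1, num_heads, head_dim), dtype=dtype, device=device)

seq_lens = torch.tensor([5, 3, 7])
buf = make_attn_output_buffer(3, seq_lens, num_heads=32, head_dim=128,
                              is_prompts=True, dtype=torch.float16, device="cpu")
assert buf.shape == (15, 32, 128)
sm_scale = 1.0 / (128 ** 0.5)   # same scaling factor passed down to the kernels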


@ -5,6 +5,7 @@ from typing import Any, List, Tuple, Union
import torch
from ordered_set import OrderedSet
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -61,6 +62,7 @@ class Sequence:
sample_params (SampleParams): The sample_params of input sequence.
block_table (torch.Tensor): The index of input sequence in block_table.
eos_token_id (int): The eos token id for this inference process.
pad_token_id (int): The pad token id for this inference process.
max_output_len (int): Maximum output length.
"""
@ -71,6 +73,7 @@ class Sequence:
sample_params: Any # SampleParams needs to be imported later.
block_table: torch.Tensor
eos_token_id: int
pad_token_id: int
max_output_len: int = 256
def __post_init__(self):
@ -167,15 +170,23 @@ class BatchInfo:
Information to be passed and used for a batch of sequences.
"""
max_batch_size: int
kv_max_split_num: int
num_heads: int
head_dim: int
sequences_set: OrderedSet[Sequence] = None
is_prompts: bool = True
device: torch.device = None
dtype: torch.dtype = None
fd_inter_tensor: FDIntermTensors = None
def __post_init__(self):
if self.device is None:
self.device = torch.cuda.current_device()
if self.sequences_set is None:
self.sequences_set = OrderedSet()
if self.fd_inter_tensor is None:
self.fd_inter_tensor = FDIntermTensors()
def init_batch(self, seqs: List["Sequence"] = None):
"""
@ -185,8 +196,6 @@ class BatchInfo:
seqs (List["Sequence"]): List of input sequence.
"""
assert len(self.sequences_set) == 0, "Sequences set has been initialized."
if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
@ -197,16 +206,30 @@ class BatchInfo:
self.sequences_set.add(seq)
def init_fd_tensors(self):
if not self.fd_inter_tensor.is_initialized:
self.fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=self.num_heads,
kv_max_split_num=self.kv_max_split_num,
head_dim=self.head_dim,
dtype=self.dtype,
device=self.device,
)
def get_block_table_tensor(self) -> torch.Tensor:
tensor_list = []
block_table = None
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
for seq in self.sequences_set:
block_table = seq.block_table
assert (
block_table is not None
), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
tensor_list.append(seq.block_table)
assert tensor_list, "Batch has not been initialized yet. Please initialize batch first."
block_table = torch.stack(tensor_list)
return block_table
@ -218,7 +241,6 @@ class BatchInfo:
"""
if self.is_prompts:
self.sequences_set.clear()
else:
for seq in self.sequences_set:
seq.mark_aborted()
@ -312,14 +334,14 @@ class BatchInfo:
"""
Get batch inputs for forward inference computation.
"""
input_list = []
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
for seq in self.sequences_set:
if self.is_prompts:
if seq.output_len > 0:
print(seq.output_token_id)
seq_data = seq.input_token_id + seq.output_token_id
print(seq_data)
input_list.append(seq.input_token_id + seq.output_token_id)
else:
input_list.append(seq.input_token_id)
@ -328,7 +350,8 @@ class BatchInfo:
max_seq_len = max(len(sub_list) for sub_list in input_list)
return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
# We assume that all sequences in the batch share the same pad_token_id at present.
return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
@ -336,6 +359,9 @@ class BatchInfo:
"""
input_list = []
input_len_list = []
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
for seq in self.sequences_set:
if self.is_prompts:
input_list.extend(seq.input_token_id)
@ -353,16 +379,23 @@ class BatchInfo:
Get the input_len of each sentence in this batch.
"""
len_list = []
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
for seq in self.sequences_set:
len_list.append(seq.sentence_len)
return torch.tensor(len_list, dtype=torch.int, device=self.device)
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
def get_attn_mask(self) -> torch.Tensor:
"""
Generate and return attention mask.
"""
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
past_values = []
# We assume that all sequences in the batch share the same pad_token_id at present.
padding_id = self.sequences_set[0].pad_token_id
for seq in self.sequences_set:
past_values.append(seq.input_token_id + seq.output_token_id)
@ -378,7 +411,7 @@ class BatchInfo:
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
return [pad] * (max_len - len(x)) + x
def _make_tensor_with_pad(
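
One behavioral detail worth calling out from this file: _pad_to_max now left-pads with the sequence's pad_token_id rather than right-padding, which is what a causal decoder expects when batching prompts of different lengths. A minimal sketch of the new behavior:

from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x   # padding now goes before the tokens

assert _pad_to_max([1, 2, 3], 5, 0) == [0, 0, 1, 2, 3]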


@ -10,7 +10,6 @@ except ImportError:
if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_attention
from .flash_decoding_utils import FDIntermTensors
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
@ -27,7 +26,6 @@ if HAS_TRITON:
"rms_layernorm",
"gptq_fused_linear_triton",
"rotary_embedding",
"FDIntermTensors",
"fused_rotary_embedding",
"get_xine_cache",
]


@ -5,7 +5,6 @@
#
# Inspired and modified from Triton Tutorial - Fused Attention
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
from typing import Optional
import torch
import triton
@ -195,7 +194,9 @@ def context_attention_unpadded(
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
block_size: int,
max_seq_len_in_b: Optional[int] = None,
output: torch.Tensor = None, # [num_tokens, num_heads, head_dim]
max_seq_len: int = None,
sm_scale: float = None,
):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk == Lv
@ -210,10 +211,9 @@ def context_attention_unpadded(
num_kv_group = num_heads // num_kv_heads
num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
sm_scale = 1.0 / (Lq**0.5)
output = torch.zeros_like(q)
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
output = torch.zeros_like(q) if output is None else output
# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
# the size of physical cache block (i.e. `block_size`)
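
The wrapper change in this file is that context_attention_unpadded now accepts an optional preallocated output tensor plus a caller-supplied max_seq_len and sm_scale, and only allocates or recomputes them when they are None. A generic sketch of that defaulting idiom follows; it is not the Triton kernel itself, and the function name is a placeholder.

import torch

def attention_like(q: torch.Tensor, context_lengths: torch.Tensor,
                   output: torch.Tensor = None, max_seq_len: int = None,
                   sm_scale: float = None) -> torch.Tensor:
    # fall back to the old behavior only when the caller supplies nothing
    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
    sm_scale = 1.0 / (q.shape[-1] ** 0.5) if sm_scale is None else sm_scale
    output = torch.zeros_like(q) if output is None else output
    # ... the kernel launch would write into `output` in place ...
    return output

q = torch.randn(10, 4, 64)                  # [num_tokens, num_heads, head_dim]
buf = torch.zeros_like(q)                   # buffer reused across layers by the caller
out = attention_like(q, torch.tensor([4, 6]), output=buf, max_seq_len=6)
assert out.data_ptr() == buf.data_ptr()     # no new allocation was made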


@ -195,6 +195,7 @@ def flash_decoding_attention(
block_tables: torch.Tensor,
block_size: int,
max_seq_len_in_batch: int = None,
output: torch.Tensor = None,
mid_output: torch.Tensor = None,
mid_output_lse: torch.Tensor = None,
sm_scale: float = None,
@ -211,6 +212,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, 1, num_heads, head_dim]
mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
@ -292,7 +294,7 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
grid = (triton.next_power_of_2(bsz), num_heads)
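
For the decode path, the caller-supplied output together with the reused mid_output / mid_output_lse from FDIntermTensors is where the "mid tensors" saving comes from: the old call site passed none of these buffers, so the wrapper presumably had to create them on every call, once per layer per decode step. A rough back-of-the-envelope sketch with hypothetical sizes (illustrative only, not measured numbers):

bsz, num_heads, head_dim = 16, 32, 128
kv_max_split_num = 24          # assumed split count, for illustration only
fp32, fp16 = 4, 2              # bytes per element

mid_output_bytes = bsz * num_heads * kv_max_split_num * head_dim * fp32
mid_output_lse_bytes = bsz * num_heads * kv_max_split_num * fp32
output_bytes = bsz * 1 * num_heads * head_dim * fp16

per_call = mid_output_bytes + mid_output_lse_bytes + output_bytes
print(f"~{per_call / 2**20:.1f} MiB of temporaries per attention call, "
      "now allocated once up front and reused")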


@ -91,7 +91,7 @@ def benchmark_inference(args):
config.pad_token_id = config.eos_token_id
model = transformers.LlamaForCausalLM(config).cuda()
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
if args.dtype == "fp16":
model = model.half()


@ -23,11 +23,12 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
# benchmark llama2-7b one single GPU
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
done
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
done


@ -17,6 +17,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
@ -28,6 +29,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
@ -39,6 +41,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
sequence.mark_running()
@ -51,7 +54,12 @@ def check_config_and_inference():
assert sequence.output_len == 0
assert sequence.check_finish() == False
batch = BatchInfo(is_prompts=False)
batch = BatchInfo(
max_batch_size=8,
kv_max_split_num=16,
num_heads=2,
head_dim=128,
)
batch.init_batch([sequence])
batch.add_seqs([sequence2, sequence3])
batch.add_seqs([sequence])


@ -3,8 +3,7 @@ import random
import numpy as np
import pytest
import torch
import transformers
from transformers import AutoTokenizer, GenerationConfig
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
@ -22,8 +21,8 @@ def setup_seed(seed):
def check_inference_engine(test_cai=False):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
).cuda()
@ -81,4 +80,4 @@ def test_inference_engine():
if __name__ == "__main__":
test_inference_engine()
test_inference_engine()


@ -20,6 +20,7 @@ def check_running_list():
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=1,
)
@ -56,6 +57,7 @@ def check_request_handler():
input_token_id=[1, 2, 3, 4, 5],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=torch.tensor([-1, -1]),
)


@ -91,6 +91,7 @@ def test_flash_decoding(
max_seq_len_in_b = kv_seq_lengths.max().item()
# The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@ -106,6 +107,7 @@ def test_flash_decoding(
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,
@ -184,6 +186,7 @@ def bench_kernel(
block_tables = block_tables.to(device=device)
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@ -199,6 +202,7 @@ def bench_kernel(
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,