mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Add CacheBlock and KV-Cache Manager (#5156)
* [Inference] Add KVCache Manager
* function refactored
* add test for KVCache Manager
* add attr beam width
* Revise alloc func in CacheManager
* Fix docs and pytests
* add tp slicing for head number
* optimize shapes of tensors used as physical cache
* Apply using InferenceConfig on KVCacheManager
* rm duplicate config file
* Optimize cache allocation: use contiguous cache
* Fix config in pytest (and config)

pull/5258/head
parent fab9b931d9
commit 3de2e62299
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn


@dataclass
class InferenceConfig:
    """The inference configuration.
@@ -24,8 +25,10 @@ class InferenceConfig:
        max_seq_len: Maximum length of the input sentence.
        quant_mode: Quantization mode.
        revision: The specific version (a branch name, a commit id, or a tag name) of the model to use.
        beam_width: The maximum beam width used to initialize the KV cache.
            During generation, the beam width provided as a sampling parameter should be less than or equal to this value.
    """

    model: Union[str, nn.Module]
    tokenizer: str = None
    tokenizer_mode: str = "auto"
@@ -34,21 +37,18 @@ class InferenceConfig:
    max_output_len: int = 256
    max_input_len: int = 256
    block_size: int = 16
    gpu_utilization_rate: float = 0.7
    dtype: Union[str, torch.dtype] = torch.float32
    tp_size: int = 1
    pp_size: int = 1
    max_seq_len: Optional[int] = None
    quant_mode: Optional[str] = None
    revision: Optional[str] = None
    # TODO: beam search is not supported for now
    beam_width: int = 1

    def __post_init__(self):
        self._verify_args()

    def _verify_args(self):
        if self.gpu_utilization_rate > 1.0:
            raise ValueError(
                f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
            )
        if self.tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(f"Tokenizer mode must be either 'auto' or 'slow', but got {self.tokenizer_mode}.")
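For readers skimming the diff, a small usage sketch of the configuration above may help. It is not part of the commit; the model string is a hypothetical placeholder, and only fields that appear in this diff or are exercised by the test at the end are used.

# Illustrative only: construct the InferenceConfig defined above and let
# __post_init__ validate it. The model string is a hypothetical placeholder.
import torch

from colossalai.inference.core.config import InferenceConfig

config = InferenceConfig(
    model="placeholder-model",  # a model name or an nn.Module would go here
    max_batch_size=4,
    block_size=16,
    dtype=torch.float16,
    tp_size=1,
)
print(config.max_seq_len, config.beam_width)  # None 1 (the defaults above)

# Values outside the checked ranges are rejected at construction time:
try:
    InferenceConfig(model="placeholder-model", gpu_utilization_rate=1.5)
except ValueError as err:
    print(err)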
@@ -0,0 +1,4 @@
from .block_cache import CacheBlock
from .kvcache_manager import KVCacheManager

__all__ = ["CacheBlock", "KVCacheManager"]
@@ -0,0 +1,56 @@
from typing import Any


class CacheBlock:
    """A simplified version of a logical cache block used for Paged Attention."""

    def __init__(self, block_id: int, block_size: int, elem_size: int, k_ptrs: Any = None, v_ptrs: Any = None):
        # Unique id of a cache block
        self.block_id = block_id

        # size/capacity of the block in terms of the number of tokens it can hold
        self.block_size = block_size

        # element size in bytes
        self.elem_size = elem_size

        # For common cases, the relationships between logical and physical caches are tracked in the KV cache manager.
        # Additionally, the k/v pointers can optionally be used by the CacheBlock itself to track the physical cache.
        self.k_ptrs = k_ptrs
        self.v_ptrs = v_ptrs

        self.ref_count = 0
        # the number of slots that have been allocated (i.e. the number of tokens occupying the block)
        self.allocated_size = 0
        # the token ids whose KV cache would be written to the corresponding physical caches
        # TODO add logic to update token_ids
        self.token_ids = [None] * self.block_size

    @property
    def available_space(self) -> int:
        # `allocated_size` is ensured to be less than or equal to `block_size`
        return self.block_size - self.allocated_size

    def add_ref(self) -> None:
        self.ref_count += 1

    def remove_ref(self) -> None:
        assert self.ref_count > 0, f"Block#{self.block_id} has no reference to remove."
        self.ref_count -= 1

    def has_ref(self) -> bool:
        return self.ref_count > 0

    def allocate(self, size: int) -> None:
        assert size <= self.available_space, f"Block#{self.block_id} has no available space to allocate."
        self.allocated_size += size

    def is_empty(self) -> bool:
        return self.allocated_size < 1

    def clear(self) -> None:
        self.ref_count = 0
        self.allocated_size = 0

    def __repr__(self):
        return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})"
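A brief, illustrative walk through the CacheBlock API above; it is not part of the commit and mirrors the assertions in the test file at the end of this diff.

# Illustrative only: reference counting and slot allocation on a single block.
from colossalai.inference.kv_cache import CacheBlock

block = CacheBlock(block_id=0, block_size=16, elem_size=2)  # e.g. fp16 elements
block.add_ref()                        # a sequence starts using this block
block.allocate(5)                      # five token slots are now occupied
assert block.available_space == 11     # 16 - 5 slots remain
block.allocate(block.available_space)  # fill the block completely
assert block.available_space == 0
block.remove_ref()
if not block.has_ref():                # no sequence references it any more
    block.clear()                      # reset so the block can be recycled
print(block)                           # CacheBlock#0(ref#0, allocated#0)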
@@ -0,0 +1,297 @@
from typing import List, Tuple

import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

from .block_cache import CacheBlock

GIGABYTE = 1024**3


def get_model_config_attr(config: PretrainedConfig, attr_name: str):
    if hasattr(config, attr_name):
        return getattr(config, attr_name)
    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
        return getattr(config, config.attribute_map[attr_name])
    raise AttributeError(f"{attr_name} is not found in config")


class KVCacheManager:
    """KVCacheManager manages both the logical cache blocks and the physical KV cache (tensors).

    NOTE: The KVCacheManager is designed to be interacted with through indices of logical blocks.
    That is, it won't allocate and return a physical cache to the engine or scheduler;
    instead, it marks the logical block as allocated and writes the block id representing
    the physical cache into the caller's block table. The physical cache is actually used and updated in kernels.

    Example
        A block table of a single sequence before block allocation might be:
        | -1 | -1 | -1 | -1 | -1 | -1 |
        where the maximum number of blocks per sequence is 6.
        The block table after block allocation might be:
        |  0 |  1 |  2 | -1 | -1 | -1 |
        Then the logical blocks with ids 0, 1, and 2 are allocated for this sequence,
        and the physical caches, each of size `block_size * head_num * head_size * elem_size` for a single layer,
        corresponding to these blocks will be used to read/write KV caches in kernels.

        For a batch of sequences, the block tables after allocation might be:
        |  0 |  1 |  2 | -1 | -1 | -1 |
        |  3 |  4 |  5 |  6 |  7 | -1 |
        |  8 |  9 | 10 | 11 | -1 | -1 |
        | 12 | 13 | 14 | 15 | -1 | -1 |
        where 16 logical cache blocks are allocated and the same number of physical cache blocks will be used in kernels.

        Currently, allocations and updates are done at the granularity of a single sequence.
        That is, the block table should be a 1D tensor of shape [max_blocks_per_sequence].
        And it's possible to have a batch of sequences with different lengths of block tables.
    """

    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
        self.logger = get_dist_logger(__name__)
        self.device = get_current_device()

        # Parallel settings
        self.tp_size = config.tp_size
        # Model settings
        self.dtype = config.dtype
        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
        # For now we focus on MHA only; TODO add handling for MQA and GQA
        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
        self.head_num //= self.tp_size
        self.beam_width = config.beam_width
        self.max_batch_size = config.max_batch_size
        self.max_input_length = config.max_input_len
        self.max_output_length = config.max_output_len
        # Cache block settings
        self.block_size = config.block_size
        # NOTE: `num_blocks` is not provided by the config; it is derived from the maximum input/output length and the maximum batch size
        self.max_blocks_per_sequence = (
            self.max_input_length + self.max_output_length + self.block_size - 1
        ) // self.block_size
        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

        # Physical cache allocation
        if verbose:
            alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
        self._kv_caches = self._init_device_caches()
        self.total_physical_cache_size_in_bytes = (
            self.elem_size_in_bytes
            * self.num_layers
            * 2
            * self.num_blocks
            * self.block_size
            * self.head_num
            * self.head_size
        )
        # Logical cache blocks allocation
        self._available_blocks = self.num_blocks
        self._cache_blocks = tuple(self._init_logical_caches())
        # block availability state: 0 -> allocated, 1 -> free
        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)

    def get_total_num_blocks(self) -> int:
        """Get the total number of logical cache blocks."""
        return self.num_blocks

    def get_num_available_blocks(self) -> int:
        """Get the number of available cache blocks."""
        return self._available_blocks

    def get_max_blocks_per_sequence(self) -> int:
        """Get the maximum number of blocks that can be allocated for a single sequence."""
        # TODO Consider removing this function, as we plan to implement "half-dynamic" batching in the scheduler/request handler,
        # which will make max_blocks_per_sequence dynamic based on the prompt lengths of sequences
        # in the current batch.
        return self.max_blocks_per_sequence

    def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
        """Get the key and value pointers of the physical caches (of a specific layer) corresponding to a logical cache block."""
        block: CacheBlock = self._cache_blocks[block_id]
        return block.k_ptrs[layer_id], block.v_ptrs[layer_id]

    def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> Tuple[List[int], List[int]]:
        """Get the key and value pointers of the physical caches (of a specific layer) corresponding to the logical cache blocks indicated by the block table."""
        k_ptrs = []
        v_ptrs = []
        for block_id in block_table:
            if block_id >= 0:
                block: CacheBlock = self._cache_blocks[block_id]
                k_ptrs.append(block.k_ptrs[layer_id])
                v_ptrs.append(block.v_ptrs[layer_id])
        return k_ptrs, v_ptrs

    def allocate_context_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
        """Allocate the logical cache blocks for a single sequence during the prefill stage,
        and update the provided block table with the allocated block ids.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing the mapping of token_position_id -> block_id.
            context_len: The length of the sequence being processed.
        """
        assert block_table.dim() == 1
        if not torch.all(block_table < 0):
            self.logger.error("Some slots on the provided block table have already been allocated.")
        blocks_required = (context_len + self.block_size - 1) // self.block_size
        if blocks_required > self._available_blocks:
            self.logger.warning(
                f"Not enough blocks to allocate. Available blocks {self._available_blocks}; context length {context_len}."
            )
            return

        # Try contiguous allocation
        torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
        torch.subtract(
            self._block_states_cum[blocks_required:],
            self._block_states_cum[:-blocks_required],
            out=self._block_finder[blocks_required - 1 :],
        )
        end_indexes = torch.nonzero(self._block_finder == blocks_required, as_tuple=False).view(-1)
        if end_indexes.numel() > 0:
            # a contiguous run of free blocks exists
            end_idx = end_indexes[0].item() + 1  # open interval
            start_idx = end_idx - blocks_required  # closed interval
            block_indexes = torch.arange(start_idx, end_idx, device=block_table.device)
        else:
            # fall back to non-contiguous free blocks
            available_block_indexes = torch.nonzero(self._block_states == 1).view(-1)
            block_indexes = available_block_indexes[:blocks_required]
        # Update block table
        block_table[:blocks_required] = block_indexes
        # Update cache blocks
        self._block_states[block_indexes] = 0
        self._available_blocks -= blocks_required
        for block_id in block_indexes.tolist():
            block: CacheBlock = self._cache_blocks[block_id]
            block.add_ref()
            if block_id == block_indexes[-1].item():
                self._allocate_on_block(
                    block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
                )
            else:
                self._allocate_on_block(block, block.block_size)

    def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
        """Allocate a logical cache block for a single sequence during the decoding stage,
        and update the provided block table if a new cache block is needed.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing the mapping of token_position_id -> block_id.
            context_len: The length of the sequence being processed (its already-allocated length).
        """
        assert block_table.dim() == 1
        # The last allocated block may be either partially or fully occupied.
        # `alloc_local_block_idx` is the index of the block to be allocated in the provided block table.
        alloc_local_block_idx = context_len // self.block_size
        self.allocate_single_block(block_table, alloc_local_block_idx, 1)

    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
        """Allocate the requested space on a single block in the block table, specified by the provided position id,
        and update the provided block table with the allocated block.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing the mapping of token_position_id -> block_id.
            block_local_idx: The index of the block in the block table.
            space_asked: The number of tokens to be assigned space for.
        Returns:
            The remaining space required to be allocated (in other blocks).
        """
        assert block_table.dim() == 1
        block_global_id = block_table[block_local_idx].item()
        if block_global_id < 0:
            # Allocate a new block if the current position has not been assigned a block yet
            assert self._available_blocks > 0, "No available blocks to allocate."
            free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0]
            block: CacheBlock = self._cache_blocks[free_block_id]
            block.add_ref()
            block_global_id = block.block_id
            self._available_blocks -= 1
            self._block_states[block_global_id] = 0
            block_table[block_local_idx] = block_global_id
        block: CacheBlock = self._cache_blocks[block_global_id]
        return self._allocate_on_block(block, space_asked)

    def free_block_table(self, block_table: torch.Tensor) -> None:
        """Free the logical cache blocks for **a single sequence**."""
        assert block_table.dim() == 1
        for i in range(block_table.numel()):
            global_block_id = block_table[i].item()
            if global_block_id < 0:
                return
            block: CacheBlock = self._cache_blocks[global_block_id]
            block.remove_ref()
            if not block.has_ref():
                block.allocated_size = 0
                self._available_blocks += 1
                self._block_states[global_block_id] = 1
            # reset the block id in the block table (if we maintain 2D tensors as block tables in the engine)
            block_table[i] = -1

    def clear_all(self) -> None:
        """Clear all the references and allocations on all the cache blocks."""
        for block in self._cache_blocks:
            block.clear()
        self._available_blocks = self.num_blocks
        self._block_states[:] = 1

    def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the key and value tensors corresponding to the cache block with the given index for a specific layer."""
        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]

    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
        """Allocate a specific amount of space on the provided cache block.

        Returns:
            The remaining space required to be allocated (in other blocks).
        """
        assert block.available_space > 0, "No available space on block to allocate."
        space_to_allocate = min(block.available_space, space_asked)
        block.allocate(space_to_allocate)
        return space_asked - space_to_allocate

    def _init_logical_caches(self):
        """Initialize the logical cache blocks.

        NOTE This function should be called only after the physical caches have been allocated.
        The data pointers of the physical caches will be bound to each logical cache block.
        """
        assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
        blocks = []
        physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
        k_ptrs = [
            self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
        ]
        v_ptrs = [
            self._kv_caches[1][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
        ]
        for i in range(self.num_blocks):
            k_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in k_ptrs]
            v_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in v_ptrs]
            cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs, v_ptrs)
            blocks.append(cache_block)
        return blocks

    def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize the physical cache on the device.

        For each layer of the model, we allocate two tensors, for keys and values respectively,
        with shape [num_blocks, num_kv_heads, head_size, block_size].
        """
        alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
        # TODO: Explore the performance when using different shapes with kernel-related optimizations
        # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
        k_cache: List[torch.Tensor] = []
        v_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
            v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
        return k_cache, v_cache
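Before the unit test in the next hunk, here is a compact, illustrative sketch of the intended call flow (prefill allocation, one decoding step, then freeing the sequence) and of the physical-cache size bookkeeping. It is not part of the commit; the sizes are deliberately tiny, and the test below checks the same behaviour with assertions.

# Illustrative only: exercise KVCacheManager end to end with a tiny configuration.
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager

inference_config = InferenceConfig(
    model="", max_batch_size=2, block_size=4, max_input_len=16, max_output_len=8, dtype=torch.float16
)
model_config = LlamaConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
manager = KVCacheManager(inference_config, model_config)

# total_physical_cache_size_in_bytes = elem_size * num_layers * 2 (K and V)
#   * num_blocks * block_size * head_num * head_size
elem_size, num_layers, block_size = 2, 2, 4  # float16, values from the configs above
head_num, head_size = 4, 128 // 4
expected_bytes = elem_size * num_layers * 2 * manager.get_total_num_blocks() * block_size * head_num * head_size
assert manager.total_physical_cache_size_in_bytes == expected_bytes

# Prefill: a 10-token prompt needs ceil(10 / 4) = 3 blocks.
block_table = torch.full((manager.get_max_blocks_per_sequence(),), -1, dtype=torch.int32)
manager.allocate_context_from_block_table(block_table, context_len=10)
assert torch.all(block_table[:3] >= 0)

# Decoding: the 11th token still fits into the partially filled third block.
manager.allocate_token_from_block_table(block_table, context_len=10)

# The sequence finishes; its blocks are returned to the free pool.
manager.free_block_table(block_table)
assert manager.get_num_available_blocks() == manager.get_total_num_blocks()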
@@ -0,0 +1,152 @@
import random

import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize


@parameterize(
    "test_config",
    [
        {
            "elem_size": 2,
            "block_size": 4,
        }
    ],
)
def test_logical_blocks(test_config):
    block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"])

    assert block.is_empty()
    assert block.available_space == test_config["block_size"]
    assert not block.has_ref()
    block.add_ref()
    assert block.ref_count == 1
    assert block.has_ref()
    block.remove_ref()
    assert block.ref_count == 0
    block.allocate(1)
    assert block.allocated_size == 1
    block.allocate(test_config["block_size"] - 1)
    assert block.available_space < 1


@parameterize(
    "test_config",
    [
        {
            "hidden_size": 512,
            "num_attention_heads": 16,
            "num_layers": 2,
            "block_size": 8,
            "max_batch_size": 10,
            "max_input_len": 32,
            "max_output_len": 32,
            "dtype": torch.float32,
            "beam_width": 1,
            "tp_size": 1,
        },
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 3,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 64,
            "max_output_len": 32,
            "dtype": torch.float16,
            "beam_width": 3,
            "tp_size": 1,
        },
    ],
)
def test_cache_manager(test_config):
    disable_existing_loggers()

    assert test_config["max_batch_size"] > 1

    hidden_size = test_config.pop("hidden_size")
    num_layers = test_config.pop("num_layers")
    num_attention_heads = test_config.pop("num_attention_heads")
    head_size = hidden_size // num_attention_heads
    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_input_length = test_config["max_input_len"]
    max_output_length = test_config["max_output_len"]

    inference_config = InferenceConfig(model="", **test_config)
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_attention_heads,
    )
    cache_manager = KVCacheManager(inference_config, model_config)

    num_blocks = cache_manager.get_total_num_blocks()
    assert num_blocks > 0
    assert len(cache_manager._cache_blocks) == num_blocks
    key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
    assert len(key_caches) == num_layers
    expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
    assert key_caches[0].shape == expected_kv_shape
    k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
    expected_kv_block_shape = expected_kv_shape[1:]
    assert k_cache_block0.shape == expected_kv_block_shape
    assert v_cache_block0.shape == expected_kv_block_shape

    max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
    block_tables = torch.tensor(
        [[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32
    )
    context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
    cnt_blocks_used = 0
    # Mock Prefill
    for req_i in range(max_batch_size):
        cur_seq_len = context_lengths[req_i]
        cur_block_table = block_tables[req_i]
        cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
        last_allocated_idx = (cur_seq_len - 1) // block_size
        assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
        cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
    assert cache_manager.get_num_available_blocks() == num_blocks - cnt_blocks_used

    # Mock Decoding
    for req_i in range(max_batch_size):
        context_length = context_lengths[req_i]
        cur_output_length = random.randint(1, max_output_length)
        cur_block_table = block_tables[req_i]
        for _ in range(cur_output_length):
            cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
            context_length += 1
        context_length -= 1
        last_allocated_idx = context_length // block_size
        space_allocated_on_last_block = context_length % block_size + 1
        assert space_allocated_on_last_block > 0
        block_id = cur_block_table[last_allocated_idx]
        block: CacheBlock = cache_manager._cache_blocks[block_id]
        assert block.allocated_size == space_allocated_on_last_block

    # Randomly select a request and free its cache
    req_i = random.randint(0, max_batch_size - 1)
    context_length = context_lengths[req_i]
    blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
    prev_available_blocks = cache_manager.get_num_available_blocks()
    cache_manager.free_block_table(block_tables[req_i])
    assert cache_manager.get_num_available_blocks() == blocks_used_by_req + prev_available_blocks

    k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
    k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
    elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
    expected_stride = block_size * num_attention_heads * head_size * elem_size
    assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
    cache_manager.clear_all()
    assert cache_manager.get_num_available_blocks() == num_blocks


if __name__ == "__main__":
    test_logical_blocks()
    test_cache_manager()