[Inference] Add CacheBlock and KV-Cache Manager (#5156)

* [Inference] Add KVCache Manager

* function refactored

* add test for KVCache Manager

* add attr beam width

* Revise alloc func in CacheManager

* Fix docs and pytests

* add tp slicing for head number

* optimize shapes of tensors used as physical cache

* Apply using InferenceConfig on KVCacheManager

* rm duplicate config file

* Optimize cache allocation: use contiguous cache

* Fix config in pytest (and config)
Yuanheng Zhao 2023-12-11 10:56:18 +08:00 committed by FrankLeeeee
parent fab9b931d9
commit 3de2e62299
5 changed files with 516 additions and 7 deletions


@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
@dataclass
class InferenceConfig:
"""The inference configuration.
@@ -24,8 +25,10 @@ class InferenceConfig:
max_seq_len: Maximum length of input sentence.
quant_mode: Quantization mode.
revision: The specific version (a branch name, tag name, or commit id) of the model to use.
beam_width: The maximum beam width used to initialize the KV Cache.
During generation, the beam width provided as a sampling parameter should be less than or equal to this value.
"""
model: Union[str, nn.Module]
tokenizer: Optional[str] = None
tokenizer_mode: str = "auto"
@@ -34,21 +37,18 @@
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
gpu_utilization_rate: float = 0.7
dtype: Union[str, torch.dtype] = torch.float32
tp_size: int = 1
pp_size: int = 1
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
# TODO: beam search is not supported for now
beam_width: int = 1
def __post_init__(self):
self._verify_args()
def _verify_args(self):
if self.gpu_utilization_rate > 1.0:
raise ValueError(
f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
)
if self.tokenizer_mode not in ["auto", "slow"]:
raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")


@@ -0,0 +1,4 @@
from .block_cache import CacheBlock
from .kvcache_manager import KVCacheManager
__all__ = ["CacheBlock", "KVCacheManager"]


@@ -0,0 +1,56 @@
from typing import Any
class CacheBlock:
"""A simplified version of logical cache block used for Paged Attention."""
def __init__(self, block_id: int, block_size: int, elem_size: int, k_ptrs: Any = None, v_ptrs: Any = None):
# Unique id of a cache block
self.block_id = block_id
# size/capacity of the block in terms of the number of tokens it can hold
self.block_size = block_size
# element size in bytes
self.elem_size = elem_size
# In common cases, the mapping between logical and physical caches is tracked by the KV Cache Manager.
# Additionally, the k/v pointers can optionally be used by the CacheBlock itself to track its physical cache.
self.k_ptrs = k_ptrs
self.v_ptrs = v_ptrs
self.ref_count = 0
# the number of slots that have been allocated (i.e. the number of tokens occupying the block)
self.allocated_size = 0
# the token ids whose KV Cache would be written to corresponding physical caches
# TODO add logic to update token_ids
self.token_ids = [None] * self.block_size
@property
def available_space(self) -> int:
# `allocated_size` is ensured to be less than or equal to `block_size`
return self.block_size - self.allocated_size
def add_ref(self) -> None:
self.ref_count += 1
def remove_ref(self) -> None:
assert self.ref_count > 0, f"Block#{self.block_id} has no reference to remove."
self.ref_count -= 1
def has_ref(self) -> bool:
return self.ref_count > 0
def allocate(self, size: int) -> None:
assert size <= self.available_space, f"Block#{self.block_id} does not have enough space for {size} slots."
self.allocated_size += size
def is_empty(self) -> bool:
return self.allocated_size < 1
def clear(self) -> None:
self.ref_count = 0
self.allocated_size = 0
def __repr__(self):
return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})"


@@ -0,0 +1,297 @@
from typing import List, Tuple
import torch
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.core.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .block_cache import CacheBlock
GIGABYTE = 1024**3
def get_model_config_attr(config: PretrainedConfig, attr_name: str):
if hasattr(config, attr_name):
return getattr(config, attr_name)
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
return getattr(config, config.attribute_map[attr_name])
raise AttributeError(f"{attr_name} is not found in config")
class KVCacheManager:
"""KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).
NOTE: The KVCacheManager is designed to be interacted with indices of logical blocks.
That is, it won't allocate and return a physical cache to the engine or scheduler;
instead, it will mark the logical block as allocated and report the block id representing
the physical cache back to the caller. The physical cache is actually used and updated in kernels.
Example
A block table of a single sequence before block allocation might be:
| -1 | -1 | -1 | -1 | -1 | -1 |
where the maximum blocks per sequence is 6
The block table after block allocation might be:
| 0 | 1 | 2 | -1 | -1 | -1 |
Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
corresponding to these blocks will be used to read/write KV Caches in kernels.
For a batch of sequences, the block tables after allocation might be:
| 0 | 1 | 2 | -1 | -1 | -1 |
| 3 | 4 | 5 | 6 | 7 | -1 |
| 8 | 9 | 10 | 11 | -1 | -1 |
| 12 | 13 | 14 | 15 | -1 | -1 |
where 16 logical cache blocks are allocated and the same number of physical cache blocks will be used in kernels.
Currently, allocations and updates are done at granularity of a single sequence.
That is, the block table should be a 1D tensor of shape [max_blocks_per_sequence].
And it's possible to have a batch of sequences with different lengths of block tables.
"""
def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
self.logger = get_dist_logger(__name__)
self.device = get_current_device()
# Parallel settings
self.tp_size = config.tp_size
# Model settings
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size
self.beam_width = config.beam_width
self.max_batch_size = config.max_batch_size
self.max_input_length = config.max_input_len
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size
# NOTE: `num_blocks` is not provided directly; it is derived from the maximum input/output lengths, the maximum batch size, and the beam width
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
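# Worked example (first config of the unit test below): with max_input_len=32, max_output_len=32,
# block_size=8, max_batch_size=10, and beam_width=1, max_blocks_per_sequence = (32 + 32 + 8 - 1) // 8 = 8
# and num_blocks = 8 * 10 * 1 = 80.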
# Physical cache allocation
if verbose:
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches()
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
* 2
* self.num_blocks
* self.block_size
* self.head_num
* self.head_size
)
# Logical cache blocks allocation
self._available_blocks = self.num_blocks
self._cache_blocks = tuple(self._init_logical_caches())
# block availability state: 0 -> allocated, 1 -> free
self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
def get_total_num_blocks(self) -> int:
"""Get the total number of logical cache blocks."""
return self.num_blocks
def get_num_available_blocks(self) -> int:
"""Get the number of available cache blocks."""
return self._available_blocks
def get_max_blocks_per_sequence(self) -> int:
"""Get the maximum number of blocks that can be allocated for a single sequence."""
# TODO Consider removing this function as we plan to implement "half-dynamic" batching in the scheduler/request handler,
# which will make the max_blocks_per_sequence dynamic based on the prompt lengths of sequences
# in the current batch.
return self.max_blocks_per_sequence
def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
"""Get the key and value pointers of physical caches (of specific layer) corresponding to a logical cache block."""
block: CacheBlock = self._cache_blocks[block_id]
return block.k_ptrs[layer_id], block.v_ptrs[layer_id]
def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> Tuple[int, int]:
"""Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table."""
k_ptrs = []
v_ptrs = []
for block_id in block_table:
if block_id >= 0:
block: CacheBlock = self._cache_blocks[block_id]
k_ptrs.append(block.k_ptrs[layer_id])
v_ptrs.append(block.v_ptrs[layer_id])
return k_ptrs, v_ptrs
def allocate_context_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
"""Allocate the logical cache blocks for a single sequence during prefill stage,
and updates the provided block table with the allocated block ids.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
context_len: The length of the sequence being processed.
"""
assert block_table.dim() == 1
if not torch.all(block_table < 0):
self.logger.error("Some slots on provided block table have been allocated.")
blocks_required = (context_len + self.block_size - 1) // self.block_size
if blocks_required > self._available_blocks:
self.logger.warning(
f"No enough blocks to allocate. Available blocks {self._available_blocks}; context length {context_len}."
)
return
# Try contiguous allocation
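# `_block_states` is 1 for free blocks, so the difference of cumulative sums below is a sliding
# window sum: an entry equals `blocks_required` exactly where a contiguous run of free blocks ends.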
torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
torch.subtract(
self._block_states_cum[blocks_required:],
self._block_states_cum[:-blocks_required],
out=self._block_finder[blocks_required - 1 :],
)
end_indexes = torch.nonzero(self._block_finder == blocks_required, as_tuple=False).view(-1)
if end_indexes.numel() > 0:
# contiguous cache exists
end_idx = end_indexes[0].item() + 1 # open interval
start_idx = end_idx - blocks_required # closed interval
block_indexes = torch.arange(start_idx, end_idx, device=block_table.device)
else:
# non-contiguous cache
available_block_indexes = torch.nonzero(self._block_states == 0).view(-1)
block_indexes = available_block_indexes[:blocks_required]
# Update block table
block_table[:blocks_required] = block_indexes
# Update cache blocks
self._block_states[block_indexes] = 0
self._available_blocks -= blocks_required
for block_id in block_indexes.tolist():
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
if block_id == block_indexes[-1].item():
self._allocate_on_block(
block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
)
else:
self._allocate_on_block(block, block.block_size)
def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
"""Allocate the logical cache block for a single sequence during decoding stage,
and updates the provided block table if a new cache block is needed.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
context_len: The length of the sequence being processed (i.e. the already-allocated length).
"""
assert block_table.dim() == 1
# The last allocated block may be either partially or fully occupied.
# `alloc_local_block_idx` is the index of the block to be allocated in the provided block table.
alloc_local_block_idx = context_len // self.block_size
self.allocate_single_block(block_table, alloc_local_block_idx, 1)
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
block_local_idx: The index of the block in the block table.
space_asked: The number of tokens to allocate space for.
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block_table.dim() == 1
block_global_id = block_table[block_local_idx].item()
if block_global_id < 0:
# Allocate a new block if the current position is not assigned a block yet
assert self._available_blocks > 0, "No available blocks to allocate."
free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0]
block: CacheBlock = self._cache_blocks[free_block_id]
block.add_ref()
block_global_id = block.block_id
self._available_blocks -= 1
self._block_states[block_global_id] = 0
block_table[block_local_idx] = block_global_id
block: CacheBlock = self._cache_blocks[block_global_id]
return self._allocate_on_block(block, space_asked)
def free_block_table(self, block_table: torch.Tensor) -> None:
"""Free the logical cache blocks for **a single sequence**."""
assert block_table.dim() == 1
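# Blocks are assigned to the table from left to right, so the first -1 entry marks the end
# of the allocated prefix and the scan can stop there.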
for i in range(block_table.numel()):
global_block_id = block_table[i].item()
if global_block_id < 0:
return
block: CacheBlock = self._cache_blocks[global_block_id]
block.remove_ref()
if not block.has_ref():
block.allocated_size = 0
self._available_blocks += 1
self._block_states[global_block_id] = 1
# reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
block_table[i] = -1
def clear_all(self) -> None:
"""Clear all the references and allocations on all the cache blocks."""
for block in self._cache_blocks:
block.clear()
self._available_blocks = self.num_blocks
self._block_states[:] = 1
def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
"""Allocate a specific size of space on a provided cache block.
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block.available_space > 0, "No available space on block to allocate."
space_to_allocate = min(block.available_space, space_asked)
block.allocate(space_to_allocate)
return space_asked - space_to_allocate
def _init_logical_caches(self):
"""Initialize the logical cache blocks.
NOTE This function should be called only after the physical caches have been allocated.
The data pointers of the physical caches will be bound to each logical cache block.
"""
assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
blocks = []
physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
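# Start the pointers one block before the tensor base so that the first loop iteration below
# advances them to block 0; block i then points at base address + i * physical_block_size.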
k_ptrs = [
self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
]
v_ptrs = [
self._kv_caches[1][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
]
for i in range(self.num_blocks):
k_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in k_ptrs]
v_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in v_ptrs]
cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs, v_ptrs)
blocks.append(cache_block)
return blocks
def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.
For each layer of the model, we allocate two tensors for key and value respectively,
with shape of [num_blocks, num_kv_heads, head_size, block_size]
"""
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
# TODO: Explore the performance when using different shapes with kernel-related optimizations
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
return k_cache, v_cache
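A short sketch of the intended calling pattern, pieced together from the docstrings above and the unit test below. The config values are illustrative, and the physical cache is allocated on the current device when the manager is constructed.

import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager

inference_config = InferenceConfig(
    model="", block_size=8, max_batch_size=1, max_input_len=32, max_output_len=32
)
model_config = LlamaConfig(hidden_size=512, num_hidden_layers=2, num_attention_heads=16)
cache_manager = KVCacheManager(inference_config, model_config)

# One block table per sequence, filled with -1 (unallocated).
block_table = torch.full((cache_manager.get_max_blocks_per_sequence(),), -1, dtype=torch.int32)

# Prefill: allocate blocks for the whole prompt in one call.
prompt_len = 13
cache_manager.allocate_context_from_block_table(block_table, prompt_len)

# Decoding: allocate one slot per generated token.
for context_len in range(prompt_len, prompt_len + 5):
    cache_manager.allocate_token_from_block_table(block_table, context_len)

# Release the sequence's blocks once generation finishes.
cache_manager.free_block_table(block_table)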


@@ -0,0 +1,152 @@
import random
import torch
from transformers.models.llama import LlamaConfig
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize
@parameterize(
"test_config",
[
{
"elem_size": 2,
"block_size": 4,
}
],
)
def test_logical_blocks(test_config):
block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"])
assert block.is_empty()
assert block.available_space == test_config["block_size"]
assert not block.has_ref()
block.add_ref()
assert block.ref_count == 1
assert block.has_ref()
block.remove_ref()
assert block.ref_count == 0
block.allocate(1)
assert block.allocated_size == 1
block.allocate(test_config["block_size"] - 1)
assert block.available_space < 1
@parameterize(
"test_config",
[
{
"hidden_size": 512,
"num_attention_heads": 16,
"num_layers": 2,
"block_size": 8,
"max_batch_size": 10,
"max_input_len": 32,
"max_output_len": 32,
"dtype": torch.float32,
"beam_width": 1,
"tp_size": 1,
},
{
"hidden_size": 128,
"num_attention_heads": 4,
"num_layers": 3,
"block_size": 4,
"max_batch_size": 4,
"max_input_len": 64,
"max_output_len": 32,
"dtype": torch.float16,
"beam_width": 3,
"tp_size": 1,
},
],
)
def test_cache_manager(test_config):
disable_existing_loggers()
assert test_config["max_batch_size"] > 1
hidden_size = test_config.pop("hidden_size")
num_layers = test_config.pop("num_layers")
num_attention_heads = test_config.pop("num_attention_heads")
head_size = hidden_size // num_attention_heads
block_size = test_config["block_size"]
max_batch_size = test_config["max_batch_size"]
max_input_length = test_config["max_input_len"]
max_output_length = test_config["max_output_len"]
inference_config = InferenceConfig(model="", **test_config)
model_config = LlamaConfig(
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=num_attention_heads,
)
cache_manager = KVCacheManager(inference_config, model_config)
num_blocks = cache_manager.get_total_num_blocks()
assert num_blocks > 0
assert len(cache_manager._cache_blocks) == num_blocks
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
assert len(key_caches) == num_layers
expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
assert key_caches[0].shape == expected_kv_shape
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
expected_kv_block_shape = expected_kv_shape[1:]
assert k_cache_block0.shape == expected_kv_block_shape
assert v_cache_block0.shape == expected_kv_block_shape
max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
block_tables = torch.tensor(
[[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32
)
context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
cnt_blocks_used = 0
# Mock Prefill
for req_i in range(max_batch_size):
cur_seq_len = context_lengths[req_i]
cur_block_table = block_tables[req_i]
cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
last_allocated_idx = (cur_seq_len - 1) // block_size
assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
assert cache_manager.get_num_available_blocks() == num_blocks - cnt_blocks_used
# Mock Decoding
for req_i in range(max_batch_size):
context_length = context_lengths[req_i]
cur_output_length = random.randint(1, max_output_length)
cur_block_table = block_tables[req_i]
for _ in range(cur_output_length):
cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
context_length += 1
context_length -= 1
last_allocated_idx = context_length // block_size
space_allocated_on_last_block = context_length % block_size + 1
assert space_allocated_on_last_block > 0
block_id = cur_block_table[last_allocated_idx]
block: CacheBlock = cache_manager._cache_blocks[block_id]
assert block.allocated_size == space_allocated_on_last_block
# Randomly select a request and clear its cache
req_i = random.randint(0, max_batch_size - 1)
context_length = context_lengths[req_i]
blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
prev_available_blocks = cache_manager.get_num_available_blocks()
cache_manager.free_block_table(block_tables[req_i])
assert cache_manager.get_num_available_blocks() == blocks_used_by_req + prev_available_blocks
k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
expected_stride = block_size * num_attention_heads * head_size * elem_size
assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
cache_manager.clear_all()
assert cache_manager.get_num_available_blocks() == num_blocks
if __name__ == "__main__":
test_logical_blocks()
test_cache_manager()