From 3de2e622995321b042d4a8cffcd61686cda4a58e Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Mon, 11 Dec 2023 10:56:18 +0800
Subject: [PATCH] [Inference] Add CacheBlock and KV-Cache Manager (#5156)

* [Inference] Add KVCache Manager
* function refactored
* add test for KVCache Manager
* add attr beam width
* Revise alloc func in CacheManager
* Fix docs and pytests
* add tp slicing for head number
* optimize shapes of tensors used as physical cache
* Apply using InferenceConfig on KVCacheManager
* rm duplicate config file
* Optimize cache allocation: use contiguous cache
* Fix config in pytest (and config)
---
 colossalai/inference/core/config.py          |  14 +-
 colossalai/inference/kv_cache/__init__.py    |   4 +
 colossalai/inference/kv_cache/block_cache.py |  56 ++++
 .../inference/kv_cache/kvcache_manager.py    | 297 ++++++++++++++++++
 tests/test_infer/test_kvcache_manager.py     | 152 +++++++++
 5 files changed, 516 insertions(+), 7 deletions(-)
 create mode 100644 colossalai/inference/kv_cache/__init__.py
 create mode 100644 colossalai/inference/kv_cache/block_cache.py
 create mode 100644 colossalai/inference/kv_cache/kvcache_manager.py
 create mode 100644 tests/test_infer/test_kvcache_manager.py

diff --git a/colossalai/inference/core/config.py b/colossalai/inference/core/config.py
index 6b44dd7af..43d0b2bb2 100644
--- a/colossalai/inference/core/config.py
+++ b/colossalai/inference/core/config.py
@@ -1,9 +1,10 @@
-from typing import Optional, Union
 from dataclasses import dataclass
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
@@ -24,8 +25,10 @@ class InferenceConfig:
         max_seq_len: Maximum length of input sentence.
         quant_mode: Quantization mode.
        revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
+        beam_width: The maximum beam width used to initialize the KV Cache.
+            During generation, the beam width provided as a sampling parameter should be less than or equal to this value.
     """
-
+
     model: Union[str, nn.Module]
     tokenizer: str = None
     tokenizer_mode: str = "auto"
@@ -34,21 +37,18 @@ class InferenceConfig:
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
-    gpu_utilization_rate: float = 0.7
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
     max_seq_len: Optional[int] = None
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
+    # TODO: beam search is not supported for now
+    beam_width: int = 1
 
     def __post_init__(self):
         self._verify_args()
 
     def _verify_args(self):
-        if self.gpu_utilization_rate > 1.0:
-            raise ValueError(
-                f"GPU utilization should be less than 1.0, but is set to {self.gpu_memory_utilization}."
-            )
         if self.tokenizer_mode not in ["auto", "slow"]:
             raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
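For orientation (not part of the patch): the capacity-related fields above are exactly what the new KVCacheManager below uses to size the KV cache. A small worked example with the default values, using the formula introduced in kvcache_manager.py:

# Worked example (illustrative only; default InferenceConfig values assumed)
max_input_len, max_output_len = 256, 256
block_size, max_batch_size, beam_width = 16, 8, 1

# ceil((max_input_len + max_output_len) / block_size) blocks per sequence
max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
num_blocks = max_blocks_per_sequence * max_batch_size * beam_width
print(max_blocks_per_sequence, num_blocks)  # 32 256
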
diff --git a/colossalai/inference/kv_cache/__init__.py b/colossalai/inference/kv_cache/__init__.py
new file mode 100644
index 000000000..c3beb5545
--- /dev/null
+++ b/colossalai/inference/kv_cache/__init__.py
@@ -0,0 +1,4 @@
+from .block_cache import CacheBlock
+from .kvcache_manager import KVCacheManager
+
+__all__ = ["CacheBlock", "KVCacheManager"]
diff --git a/colossalai/inference/kv_cache/block_cache.py b/colossalai/inference/kv_cache/block_cache.py
new file mode 100644
index 000000000..c9a38e2d5
--- /dev/null
+++ b/colossalai/inference/kv_cache/block_cache.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+
+class CacheBlock:
+    """A simplified version of the logical cache block used for Paged Attention."""
+
+    def __init__(self, block_id: int, block_size: int, elem_size: int, k_ptrs: Any = None, v_ptrs: Any = None):
+        # Unique id of a cache block
+        self.block_id = block_id
+
+        # size/capacity of the block in terms of the number of tokens it can hold
+        self.block_size = block_size
+
+        # element size in bytes
+        self.elem_size = elem_size
+
+        # For common cases, we track the relationships between logical and physical caches in the KV Cache Manager.
+        # Additionally, k, v pointers can be optionally used for tracking the physical cache by CacheBlock itself.
+        self.k_ptrs = k_ptrs
+        self.v_ptrs = v_ptrs
+
+        self.ref_count = 0
+        # the number of slots that have been allocated (i.e. the number of tokens occupying the block)
+        self.allocated_size = 0
+        # the token ids whose KV Cache would be written to corresponding physical caches
+        # TODO add logic to update token_ids
+        self.token_ids = [None] * self.block_size
+
+    @property
+    def available_space(self) -> int:
+        # `allocated_size` is ensured to be less than or equal to `block_size`
+        return self.block_size - self.allocated_size
+
+    def add_ref(self) -> None:
+        self.ref_count += 1
+
+    def remove_ref(self) -> None:
+        assert self.ref_count > 0, f"Block#{self.block_id} has no reference to remove."
+        self.ref_count -= 1
+
+    def has_ref(self) -> bool:
+        return self.ref_count > 0
+
+    def allocate(self, size: int) -> None:
+        assert size <= self.available_space, f"Block#{self.block_id} has no available space to allocate."
+        self.allocated_size += size
+
+    def is_empty(self):
+        return self.allocated_size < 1
+
+    def clear(self) -> None:
+        self.ref_count = 0
+        self.allocated_size = 0
+
+    def __repr__(self):
+        return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})"
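A minimal sketch (not part of the patch) of how the CacheBlock bookkeeping above is meant to be driven; test_logical_blocks at the end of this patch exercises the same API, and the sizes here are illustrative:

from colossalai.inference.kv_cache import CacheBlock

block = CacheBlock(block_id=0, block_size=16, elem_size=2)  # illustrative sizes
assert block.is_empty() and block.available_space == 16

block.add_ref()      # a sequence starts using this block
block.allocate(5)    # reserve slots for 5 tokens
assert block.available_space == 11

block.remove_ref()   # the sequence no longer needs the block
if not block.has_ref():
    block.clear()    # reset ref_count and allocated_size
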
+ """ + + def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: + self.logger = get_dist_logger(__name__) + self.device = get_current_device() + + # Parallel settings + self.tp_size = config.tp_size + # Model settings + self.dtype = config.dtype + self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() + self.num_layers = get_model_config_attr(model_config, "num_hidden_layers") + # For now we focus on MHA only, TODO add handling for MQA and GQA + self.head_num = get_model_config_attr(model_config, "num_attention_heads") + self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num + assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" + self.head_num //= self.tp_size + self.beam_width = config.beam_width + self.max_batch_size = config.max_batch_size + self.max_input_length = config.max_input_len + self.max_output_length = config.max_output_len + # Cache block settings + self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + + # Physical cache allocation + if verbose: + alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size) + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + self._kv_caches = self._init_device_caches() + self.total_physical_cache_size_in_bytes = ( + self.elem_size_in_bytes + * self.num_layers + * 2 + * self.num_blocks + * self.block_size + * self.head_num + * self.head_size + ) + # Logical cache blocks allocation + self._available_blocks = self.num_blocks + self._cache_blocks = tuple(self._init_logical_caches()) + # block availablity state 0->allocated, 1->free + self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool) + self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64) + self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64) + + def get_total_num_blocks(self) -> int: + """Get the total number of logical cache blocks.""" + return self.num_blocks + + def get_num_available_blocks(self) -> int: + """Get the number of available cache blocks.""" + return self._available_blocks + + def get_max_blocks_per_sequence(self) -> int: + """Get the maximum number of blocks that can be allocated for a single sequence.""" + # TODO Consider removing this function as we plan to implement "half-dynamic" batching in schduler/request handler, + # which will make the max_blocks_per_sequence dynamic based on the prompt lengths of sequences + # in the current batch. 
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+        self.logger = get_dist_logger(__name__)
+        self.device = get_current_device()
+
+        # Parallel settings
+        self.tp_size = config.tp_size
+        # Model settings
+        self.dtype = config.dtype
+        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
+        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
+        # For now we focus on MHA only, TODO add handling for MQA and GQA
+        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
+        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
+        self.head_num //= self.tp_size
+        self.beam_width = config.beam_width
+        self.max_batch_size = config.max_batch_size
+        self.max_input_length = config.max_input_len
+        self.max_output_length = config.max_output_len
+        # Cache block settings
+        self.block_size = config.block_size
+        # NOTE: `num_blocks` is not configured directly; it is derived from the maximum input/output length and the maximum batch size
+        self.max_blocks_per_sequence = (
+            self.max_input_length + self.max_output_length + self.block_size - 1
+        ) // self.block_size
+        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
+
+        # Physical cache allocation
+        if verbose:
+            alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        self._kv_caches = self._init_device_caches()
+        self.total_physical_cache_size_in_bytes = (
+            self.elem_size_in_bytes
+            * self.num_layers
+            * 2
+            * self.num_blocks
+            * self.block_size
+            * self.head_num
+            * self.head_size
+        )
+        # Logical cache blocks allocation
+        self._available_blocks = self.num_blocks
+        self._cache_blocks = tuple(self._init_logical_caches())
+        # block availability state: 0 -> allocated, 1 -> free
+        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
+        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
+        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
+
+    def get_total_num_blocks(self) -> int:
+        """Get the total number of logical cache blocks."""
+        return self.num_blocks
+
+    def get_num_available_blocks(self) -> int:
+        """Get the number of available cache blocks."""
+        return self._available_blocks
+
+    def get_max_blocks_per_sequence(self) -> int:
+        """Get the maximum number of blocks that can be allocated for a single sequence."""
+        # TODO Consider removing this function, as we plan to implement "half-dynamic" batching in the scheduler/request handler,
+        #   which will make max_blocks_per_sequence dynamic based on the prompt lengths of the sequences in the current batch.
+        return self.max_blocks_per_sequence
+
+ """ + assert block_table.dim() == 1 + # The last allocated block may be either partially or fully occupied. + # `alloc_local_block_idx` is the index of block to be allocated on provided block table. + alloc_local_block_idx = context_len // self.block_size + self.allocate_single_block(block_table, alloc_local_block_idx, 1) + + def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int: + """Allocate space asked on a single block in the block table, specified by the provided position id, + and updates the provided block table with the allocated block. + + Args: + block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id. + block_local_idx: The index of the block in the block table. + space_asked: i.e. The number of tokens to be assigned space for. + Returns: + The remaining space required to be allocated (in other blocks). + """ + assert block_table.dim() == 1 + block_global_id = block_table[block_local_idx].item() + if block_global_id < 0: + # Allocate a new block if the current position is not assigned a block yet + assert self._available_blocks > 0, "No available blocks to allocate." + free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0] + block: CacheBlock = self._cache_blocks[free_block_id] + block.add_ref() + block_global_id = block.block_id + self._available_blocks -= 1 + self._block_states[block_global_id] = 0 + block_table[block_local_idx] = block_global_id + block: CacheBlock = self._cache_blocks[block_global_id] + return self._allocate_on_block(block, space_asked) + + def free_block_table(self, block_table: torch.Tensor) -> None: + """Free the logical cache blocks for **a single sequence**.""" + assert block_table.dim() == 1 + for i in range(block_table.numel()): + global_block_id = block_table[i].item() + if global_block_id < 0: + return + block: CacheBlock = self._cache_blocks[global_block_id] + block.remove_ref() + if not block.has_ref(): + block.allocated_size = 0 + self._available_blocks += 1 + self._block_states[global_block_id] = 1 + # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine) + block_table[i] = -1 + + def clear_all(self) -> None: + """Clear all the references and allocations on all the cache blocks.""" + for block in self._cache_blocks: + block.clear() + self._available_blocks = self.num_blocks + self._block_states[:] = 1 + + def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" + return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] + + def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int: + """Allocate a specific size of space on a provided cache block. + + Returns: + The remaining space required to be allocated (in other blocks). + """ + assert block.available_space > 0, "No available space on block to allocate." + space_to_allocate = min(block.available_space, space_asked) + block.allocate(space_to_allocate) + return space_asked - space_to_allocate + + def _init_logical_caches(self): + """Initialize the logical cache blocks. + + NOTE This function should be called only after the physical caches have been allocated. + The data pointers of physical caches will be binded to each logical cache block. 
+ """ + assert self._kv_caches is not None and len(self._kv_caches[0]) > 0 + blocks = [] + physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size + k_ptrs = [ + self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers) + ] + v_ptrs = [ + self._kv_caches[1][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers) + ] + for i in range(self.num_blocks): + k_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in k_ptrs] + v_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in v_ptrs] + cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs, v_ptrs) + blocks.append(cache_block) + return blocks + + def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Initialize the physical cache on the device. + + For each layer of the model, we allocate two tensors for key and value respectively, + with shape of [num_blocks, num_kv_heads, head_size, block_size] + """ + alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size) + # TODO: Explore the performance when using difference shapes with kernel-related optimizations + # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x] + k_cache: List[torch.Tensor] = [] + v_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)) + v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device)) + return k_cache, v_cache diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py new file mode 100644 index 000000000..ee37f3ce1 --- /dev/null +++ b/tests/test_infer/test_kvcache_manager.py @@ -0,0 +1,152 @@ +import random + +import torch +from transformers.models.llama import LlamaConfig + +from colossalai.inference.core.config import InferenceConfig +from colossalai.inference.kv_cache import CacheBlock, KVCacheManager +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize + + +@parameterize( + "test_config", + [ + { + "elem_size": 2, + "block_size": 4, + } + ], +) +def test_logical_blocks(test_config): + block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"]) + + assert block.is_empty() + assert block.available_space == test_config["block_size"] + assert not block.has_ref() + block.add_ref() + assert block.ref_count == 1 + assert block.has_ref() + block.remove_ref() + assert block.ref_count == 0 + block.allocate(1) + assert block.allocated_size == 1 + block.allocate(test_config["block_size"] - 1) + assert block.available_space < 1 + + +@parameterize( + "test_config", + [ + { + "hidden_size": 512, + "num_attention_heads": 16, + "num_layers": 2, + "block_size": 8, + "max_batch_size": 10, + "max_input_len": 32, + "max_output_len": 32, + "dtype": torch.float32, + "beam_width": 1, + "tp_size": 1, + }, + { + "hidden_size": 128, + "num_attention_heads": 4, + "num_layers": 3, + "block_size": 4, + "max_batch_size": 4, + "max_input_len": 64, + "max_output_len": 32, + "dtype": torch.float16, + "beam_width": 3, + "tp_size": 1, + }, + ], +) +def test_cache_manager(test_config): + disable_existing_loggers() + + assert test_config["max_batch_size"] > 1 + + hidden_size = test_config.pop("hidden_size") + num_layers = test_config.pop("num_layers") + num_attention_heads = test_config.pop("num_attention_heads") + head_size = hidden_size // 
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
new file mode 100644
index 000000000..ee37f3ce1
--- /dev/null
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -0,0 +1,152 @@
+import random
+
+import torch
+from transformers.models.llama import LlamaConfig
+
+from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "elem_size": 2,
+            "block_size": 4,
+        }
+    ],
+)
+def test_logical_blocks(test_config):
+    block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"])
+
+    assert block.is_empty()
+    assert block.available_space == test_config["block_size"]
+    assert not block.has_ref()
+    block.add_ref()
+    assert block.ref_count == 1
+    assert block.has_ref()
+    block.remove_ref()
+    assert block.ref_count == 0
+    block.allocate(1)
+    assert block.allocated_size == 1
+    block.allocate(test_config["block_size"] - 1)
+    assert block.available_space < 1
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "hidden_size": 512,
+            "num_attention_heads": 16,
+            "num_layers": 2,
+            "block_size": 8,
+            "max_batch_size": 10,
+            "max_input_len": 32,
+            "max_output_len": 32,
+            "dtype": torch.float32,
+            "beam_width": 1,
+            "tp_size": 1,
+        },
+        {
+            "hidden_size": 128,
+            "num_attention_heads": 4,
+            "num_layers": 3,
+            "block_size": 4,
+            "max_batch_size": 4,
+            "max_input_len": 64,
+            "max_output_len": 32,
+            "dtype": torch.float16,
+            "beam_width": 3,
+            "tp_size": 1,
+        },
+    ],
+)
+def test_cache_manager(test_config):
+    disable_existing_loggers()
+
+    assert test_config["max_batch_size"] > 1
+
+    hidden_size = test_config.pop("hidden_size")
+    num_layers = test_config.pop("num_layers")
+    num_attention_heads = test_config.pop("num_attention_heads")
+    head_size = hidden_size // num_attention_heads
+    block_size = test_config["block_size"]
+    max_batch_size = test_config["max_batch_size"]
+    max_input_length = test_config["max_input_len"]
+    max_output_length = test_config["max_output_len"]
+
+    inference_config = InferenceConfig(model="", **test_config)
+    model_config = LlamaConfig(
+        hidden_size=hidden_size,
+        num_hidden_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+    )
+    cache_manager = KVCacheManager(inference_config, model_config)
+
+    num_blocks = cache_manager.get_total_num_blocks()
+    assert num_blocks > 0
+    assert len(cache_manager._cache_blocks) == num_blocks
+    key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
+    assert len(key_caches) == num_layers
+    expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
+    assert key_caches[0].shape == expected_kv_shape
+    k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
+    expected_kv_block_shape = expected_kv_shape[1:]
+    assert k_cache_block0.shape == expected_kv_block_shape
+    assert v_cache_block0.shape == expected_kv_block_shape
+
+    max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
+    block_tables = torch.tensor(
+        [[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32
+    )
+    context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
+    cnt_blocks_used = 0
+    # Mock Prefill
+    for req_i in range(max_batch_size):
+        cur_seq_len = context_lengths[req_i]
+        cur_block_table = block_tables[req_i]
+        cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
+        last_allocated_idx = (cur_seq_len - 1) // block_size
+        assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
+        cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
+        assert cache_manager.get_num_available_blocks() == num_blocks - cnt_blocks_used
+
+    # Mock Decoding
+    for req_i in range(max_batch_size):
+        context_length = context_lengths[req_i]
+        cur_output_length = random.randint(1, max_output_length)
+        cur_block_table = block_tables[req_i]
+        for _ in range(cur_output_length):
+            cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
+            context_length += 1
+        context_length -= 1
+        last_allocated_idx = context_length // block_size
+        space_allocated_on_last_block = context_length % block_size + 1
+        assert space_allocated_on_last_block > 0
+        block_id = cur_block_table[last_allocated_idx]
+        block: CacheBlock = cache_manager._cache_blocks[block_id]
+        assert block.allocated_size == space_allocated_on_last_block
+
+    # Randomly select a request and clear its cache
+    req_i = random.randint(0, max_batch_size - 1)
+    context_length = context_lengths[req_i]
+    blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
+    prev_available_blocks = cache_manager.get_num_available_blocks()
+    cache_manager.free_block_table(block_tables[req_i])
+    assert cache_manager.get_num_available_blocks() == blocks_used_by_req + prev_available_blocks
+
+    k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
+    k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
+    elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
+    expected_stride = block_size * num_attention_heads * head_size * elem_size
+    assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
+    cache_manager.clear_all()
+    assert cache_manager.get_num_available_blocks() == num_blocks
+
+
+if __name__ == "__main__":
+    test_logical_blocks()
+    test_cache_manager()
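Finally, a minimal usage sketch of the new manager (not part of the patch; it mirrors test_cache_manager above, and the parameter values are illustrative):

# Minimal KVCacheManager usage sketch (illustrative values; the device is picked by get_current_device())
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager

inference_config = InferenceConfig(
    model="", block_size=8, max_batch_size=4, max_input_len=32, max_output_len=32, dtype=torch.float16
)
model_config = LlamaConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
cache_manager = KVCacheManager(inference_config, model_config)

# One block table per sequence, filled with -1 (unallocated slots).
block_table = torch.full((cache_manager.get_max_blocks_per_sequence(),), -1, dtype=torch.int32)

cache_manager.allocate_context_from_block_table(block_table, context_len=20)  # prefill: 3 blocks of size 8
cache_manager.allocate_token_from_block_table(block_table, context_len=20)    # decode: the 21st token goes into the third block
print(block_table, cache_manager.get_num_available_blocks())

cache_manager.free_block_table(block_table)  # release the blocks once the sequence finishes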